Compare commits

..

4 Commits

Author SHA1 Message Date
Dhruv Nair abf4a9271e skip test 2023-09-19 12:39:40 +00:00
Dhruv Nair 0e1fb0d916 merge upstream 2023-09-19 11:27:08 +00:00
Dhruv Nair f77b7a0f27 fix tests 2023-09-19 04:32:19 +00:00
Dhruv Nair eae1371983 wip 2023-09-19 03:37:22 +00:00
130 changed files with 1090 additions and 7183 deletions
+1 -2
View File
@@ -15,7 +15,6 @@ body:
*The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*
- 3. Add the **minimum amount of code / context that is needed to understand, reproduce your issue**.
*Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*
- 4. For issues related to community pipelines (i.e., the pipelines located in the `examples/community` folder), please tag the author of the pipeline in your issue thread as those pipelines are not maintained.
- type: markdown
attributes:
value: |
@@ -71,7 +70,7 @@ body:
Questions on schedulers: @patrickvonplaten and @williamberman
Questions on models and pipelines: @patrickvonplaten, @sayakpaul, and @williamberman (for community pipelines, please tag the original author of the pipeline)
Questions on models and pipelines: @patrickvonplaten, @sayakpaul, and @williamberman
Questions on JAX- and MPS-related things: @pcuenca
+1 -1
View File
@@ -20,7 +20,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.7"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
+2 -2
View File
@@ -20,7 +20,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.7"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
@@ -38,7 +38,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.7"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
@@ -1,67 +0,0 @@
name: Fast tests for PRs - PEFT backend
on:
pull_request:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
jobs:
run_fast_tests:
strategy:
fail-fast: false
matrix:
config:
- name: LoRA
framework: lora
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_lora
name: ${{ matrix.config.name }}
runs-on: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate.git
python -m pip install -U git+https://github.com/huggingface/transformers.git
python -m pip install -U git+https://github.com/huggingface/peft.git
- name: Environment
run: |
python utils/print_env.py
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
if: ${{ matrix.config.framework == 'lora' }}
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_${{ matrix.config.report }} \
tests/lora/test_lora_layers_peft.py
-1
View File
@@ -15,7 +15,6 @@ concurrency:
env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 4
HF_HOME: /mnt/cache
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
+1 -1
View File
@@ -17,7 +17,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: 3.7
- name: Install requirements
run: |
-2
View File
@@ -216,8 +216,6 @@
title: AudioLDM 2
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP Diffusion
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
@@ -1,29 +0,0 @@
# Blip Diffusion
Blip Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
The abstract from the paper is:
*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.*
The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization.
`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).
<Tip>
Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## BlipDiffusionPipeline
[[autodoc]] BlipDiffusionPipeline
- all
- __call__
## BlipDiffusionControlNetPipeline
[[autodoc]] BlipDiffusionControlNetPipeline
- all
- __call__
+2 -2
View File
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
Install 🤗 Diffusers for whichever deep learning library you're working with.
🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
@@ -106,7 +106,7 @@ pip install -e ".[flax]"
These commands will link the folder you cloned the repository to and your Python library paths.
Python will now look inside the folder you cloned to in addition to the normal library paths.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
<Tip warning={true}>
+62 -568
View File
@@ -10,597 +10,91 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Image-to-image
# Text-guided image-to-image generation
[[open-in-colab]]
Image-to-image is similar to [text-to-image](conditional_image_generation), but in addition to a prompt, you can also pass an initial image as a starting point for the diffusion process. The initial image is encoded to latent space and noise is added to it. Then the latent diffusion model takes a prompt and the noisy latent image, predicts the added noise, and removes the predicted noise from the initial latent image to get the new latent image. Lastly, a decoder decodes the new latent image back into an image.
The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images.
With 🤗 Diffusers, this is as easy as 1-2-3:
1. Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint:
Before you begin, make sure you have all the necessary libraries installed:
```py
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# uncomment to install the necessary libraries in Colab
#!pip install diffusers transformers ftfy accelerate
```
Get started by creating a [`StableDiffusionImg2ImgPipeline`] with a pretrained Stable Diffusion model like [`nitrosocke/Ghibli-Diffusion`](https://huggingface.co/nitrosocke/Ghibli-Diffusion).
```python
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionImg2ImgPipeline
device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"nitrosocke/Ghibli-Diffusion", torch_dtype=torch.float16, use_safetensors=True
).to(device)
```
Download and preprocess an initial image so you can pass it to the pipeline:
```python
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image.thumbnail((768, 768))
init_image
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/image_2_image_using_diffusers_cell_8_output_0.jpeg"/>
</div>
<Tip>
You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](/optimization/torch2.0#scaled-dot-product-attention).
💡 `strength` is a value between 0.0 and 1.0 that controls the amount of noise added to the input image. Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
</Tip>
2. Load an image to pass to the pipeline:
Define the prompt (for this checkpoint finetuned on Ghibli-style art, you need to prefix the prompt with the `ghibli style` tokens) and run the pipeline:
```py
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
```
3. Pass a prompt and image to the pipeline to generate an image:
```py
prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
image = pipeline(prompt, image=init_image).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
## Popular models
The most popular image-to-image models are [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). The results from the Stable Diffusion and Kandinsky models vary due to their architecture differences and training process; you can generally expect SDXL to produce higher quality images than Stable Diffusion v1.5. Let's take a quick look at how to use each of these models and compare their results.
### Stable Diffusion v1.5
Stable Diffusion v1.5 is a latent diffusion model intialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image:
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdv1.5.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
### Stable Diffusion XL (SDXL)
SDXL is a more powerful version of the Stable Diffusion model. It uses a larger base model, and an additional refiner model to increase the quality of the base model's output. Read the [SDXL](sdxl) guide for a more detailed walkthrough of how to use this model, and other techniques it uses to produce high quality images.
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, strength=).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
### Kandinsky 2.2
The Kandinsky model is different from the Stable Diffusion models because it uses an image prior model to create image embeddings. The embeddings help create a better alignment between text and images, allowing the latent diffusion model to generate better images.
The simplest way to use Kandinsky 2.2 is:
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-kandinsky.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
## Configure pipeline parameters
There are several important parameters you can configure in the pipeline that'll affect the image generation process and image quality. Let's take a closer look at what these parameters do and how changing them affects the output.
### Strength
`strength` is one of the most important parameters to consider and it'll have a huge impact on your generated image. It determines how much the generated image resembles the initial image. In other words:
- 📈 a higher `strength` value gives the model more "creativity" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored
- 📉 a lower `strength` value means the generated image is more similar to the initial image
The `strength` and `num_inference_steps` parameter are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = init_image
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, strength=0.8).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-0.4.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 0.4</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-0.6.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 0.6</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-1.0.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 1.0</figcaption>
</div>
</div>
### Guidance scale
The `guidance_scale` parameter is used to control how closely aligned the generated image and text prompt are. A higher `guidance_scale` value means your generated image is more aligned with the prompt, while a lower `guidance_scale` value means your generated image has more space to deviate from the prompt.
You can combine `guidance_scale` with `strength` for even more precise control over how expressive the model is. For example, combine a high `strength + guidance_scale` for maximum creativity or use a combination of low `strength` and low `guidance_scale` to generate an image that resembles the initial image but is not as strictly bound to the prompt.
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-0.1.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 0.1</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-3.0.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 5.0</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-7.5.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 10.0</figcaption>
</div>
</div>
### Negative prompt
A negative prompt conditions the model to *not* include things in an image, and it can be used to improve image quality or modify an image. For example, you can improve image quality by including negative prompts like "poor details" or "blurry" to encourage the model to generate a higher quality image. Or you can modify an image by specifying things to exclude from an image.
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
# pass prompt and image to pipeline
image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-negative-1.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">negative prompt = "ugly, deformed, disfigured, poor details, bad anatomy"</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-negative-2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">negative prompt = "jungle"</figcaption>
</div>
</div>
## Chained image-to-image pipelines
There are some other interesting ways you can use an image-to-image pipeline aside from just generating an image (although that is pretty cool too). You can take it a step further and chain it with other pipelines.
### Text-to-image-to-image
Chaining a text-to-image and image-to-image pipeline allows you to generate an image from text and use the generated image as the initial image for the image-to-image pipeline. This is useful if you want to generate an image entirely from scratch. For example, let's chain a Stable Diffusion and a Kandinsky model.
Start by generating an image with the text-to-image pipeline:
```py
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
```
Now you can pass this generated image to the image-to-image pipeline:
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=image).images[0]
image
```
### Image-to-image-to-image
You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generate short GIFs, restore color to an image, or restore missing areas of an image.
Start by generating an image:
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, output_type="latent").images[0]
```
<Tip>
It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
</Tip>
Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion):
```py
pipelne = AutoPipelineForImage2Image.from_pretrained(
"ogkalu/Comic-Diffusion", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# need to include the token "charliebo artstyle" in the prompt to use this checkpoint
image = pipeline("Astronaut in a jungle, charliebo artstyle", image=image, output_type="latent").images[0]
```
Repeat one more time to generate the final image in a [pixel art style](https://huggingface.co/kohbanye/pixel-art-style):
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
"kohbanye/pixel-art-style", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# need to include the token "pixelartstyle" in the prompt to use this checkpoint
image = pipeline("Astronaut in a jungle, pixelartstyle", image=image).images[0]
image
```
### Image-to-upscaler-to-super-resolution
Another way you can chain your image-to-image pipeline is with an upscaler and super-resolution pipeline to really increase the level of details in an image.
Start with an image-to-image pipeline:
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image_1 = pipeline(prompt, image=init_image, output_type="latent").images[0]
```
<Tip>
It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
</Tip>
Chain it to an upscaler pipeline to increase the image resolution:
```py
upscaler = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
upscaler.enable_model_cpu_offload()
upscaler.enable_xformers_memory_efficient_attention()
image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
```
Finally, chain it to a super-resolution pipeline to further enhance the resolution:
```py
super_res = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
super_res.enable_model_cpu_offload()
super_res.enable_xformers_memory_efficient_attention()
image_3 = upscaler(prompt, image=image_2).images[0]
image_3
```
## Control image generation
Trying to generate an image that looks exactly the way you want can be difficult, which is why controlled generation techniques and models are so useful. While you can use the `negative_prompt` to partially control image generation, there are more robust methods like prompt weighting and ControlNets.
### Prompt weighting
Prompt weighting allows you to scale the representation of each concept in a prompt. For example, in a prompt like "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", you can choose to increase or decrease the embeddings of "astronaut" and "jungle". The [Compel](https://github.com/damian0815/compel) library provides a simple syntax for adjusting prompt weights and generating the embeddings. You can learn how to create the embeddings in the [Prompt weighting](weighted_prompts) guide.
[`AutoPipelineForImage2Image`] has a `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter where you can pass the embeddings which replaces the `prompt` parameter.
```py
from diffusers import AutoPipelineForImage2Image
import torch
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
image = pipeline(prompt_emebds=prompt_embeds, # generated from Compel
negative_prompt_embeds, # generated from Compel
image=init_image,
).images[0]
```
### ControlNet
ControlNets provide a more flexible and accurate way to control image generation because you can use an additional conditioning image. The conditioning image can be a canny image, depth map, image segmentation, and even scribbles! Whatever type of conditioning image you choose, the ControlNet generates an image that preserves the information in it.
For example, let's condition an image with a depth map to keep the spatial information in the image.
```py
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((958, 960)) # resize to depth image dimensions
depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png")
```
Load a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]:
```py
from diffusers import ControlNetModel, AutoPipelineForImage2Image
from diffusers.utils import load_image
import torch
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
```
Now generate a new image conditioned on the depth map, initial image, and prompt:
```py
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipeline(prompt, image=init_image, control_image=depth_image).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">depth image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-controlnet.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">ControlNet image</figcaption>
</div>
</div>
Let's apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion) to the image generated from the ControlNet by chaining it with an image-to-image pipeline:
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
"nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
prompt = "elden ring style astronaut in a jungle" # include the token "elden ring style" in the prompt
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image, strength=0.45, guidance_scale=10.5).images[0]
```python
prompt = "ghibli style, a fantasy landscape with castles"
generator = torch.Generator(device=device).manual_seed(1024)
image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
image
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-elden-ring.png">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ghibli-castles.png"/>
</div>
## Optimize
You can also try experimenting with a different scheduler to see how that affects the output:
Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](optimization/torch2.0#scaled-dot-product-attention) or [xFormers](optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU.
```python
from diffusers import LMSDiscreteScheduler
```diff
+ pipeline.enable_model_cpu_offload()
+ pipeline.enable_xformers_memory_efficient_attention()
lms = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = lms
generator = torch.Generator(device=device).manual_seed(1024)
image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
image
```
With [`torch.compile`](optimization/torch2.0#torch.compile), you can boost your inference speed even more by wrapping your UNet with it:
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lms-ghibli.png"/>
</div>
```py
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
Check out the Spaces below, and try generating images with different values for `strength`. You'll notice that using lower values for `strength` produces images that are more similar to the original image.
To learn more, take a look at the [Reduce memory usage](optimization/memory) and [Torch 2.0](optimization/torch2.0) guides.
Feel free to also switch the scheduler to the [`LMSDiscreteScheduler`] and see how that affects the output.
<iframe
src="https://stevhliu-ghibli-img2img.hf.space"
frameborder="0"
width="850"
height="500"
></iframe>
@@ -0,0 +1,19 @@
# What is safetensors ?
[safetensors](https://github.com/huggingface/safetensors) is a different format
from the classic `.bin` which uses Pytorch which uses pickle.
Pickle is notoriously unsafe which allow any malicious file to execute arbitrary code.
The hub itself tries to prevent issues from it, but it's not a silver bullet.
`safetensors` first and foremost goal is to make loading machine learning models *safe*
in the sense that no takeover of your computer can be done.
# Why use safetensors ?
**Safety** can be one reason, if you're attempting to use a not well known model and
you're not sure about the source of the file.
And a secondary reason, is **the speed of loading**. Safetensors can load models much faster
than regular pickle files. If you spend a lot of times switching models, this can be
a huge timesave.
+2 -2
View File
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
사용하시는 라이브러리에 맞는 🤗 Diffusers를 설치하세요.
🤗 Diffusers는 Python 3.8+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요.
🤗 Diffusers는 Python 3.7+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요.
- [PyTorch 설치 안내](https://pytorch.org/get-started/locally/)
- [Flax 설치 안내](https://flax.readthedocs.io/en/latest/)
@@ -105,7 +105,7 @@ pip install -e ".[flax]"
이러한 명령어들은 저장소를 복제한 폴더와 Python 라이브러리 경로를 연결합니다.
Python은 이제 일반 라이브러리 경로에 더하여 복제한 폴더 내부를 살펴봅니다.
예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.8/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.7/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
<Tip warning={true}>
+2 -2
View File
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
在你正在使用的任意深度学习框架中安装 🤗 Diffusers 。
🤗 Diffusers已在Python 3.8+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装:
🤗 Diffusers已在Python 3.7+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装:
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
@@ -107,7 +107,7 @@ pip install -e ".[flax]"
这些命令将连接到你克隆的版本库和你的 Python 库路径。
现在,不只是在通常的库路径,Python 还会在你克隆的文件夹内寻找包。
例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.8/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`
例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.7/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`
<Tip warning={true}>
@@ -908,9 +908,6 @@ def main():
if args.snr_gamma is not None:
snr = jnp.array(compute_snr(timesteps))
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
snr_loss_weights = snr_loss_weights + 1
loss = loss * snr_loss_weights
loss = loss.mean()
+6 -54
View File
@@ -224,30 +224,6 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
raise ValueError(f"{model_class} is not supported.")
def compute_snr(timesteps, noise_scheduler):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
# Compute SNR
snr = (alpha / sigma) ** 2
return snr
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -548,13 +524,6 @@ def parse_args(input_args=None):
" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
),
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--pre_compute_text_embeddings",
action="store_true",
@@ -1292,34 +1261,17 @@ def main(args):
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Compute instance loss
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps, noise_scheduler)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
if args.with_prior_preservation:
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -875,9 +875,6 @@ def main():
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -955,9 +955,6 @@ def main():
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -786,9 +786,6 @@ def main():
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -1075,9 +1075,6 @@ def main(args):
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -332,6 +332,15 @@ def parse_args(input_args=None):
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--force_snr_gamma",
action="store_true",
help=(
"When using SNR gamma with rescaled betas for zero terminal SNR, a divide-by-zero error can cause NaN"
" condition when computing the SNR with a sigma value of zero. This parameter overrides the check,"
" allowing the use of SNR gamma with a terminal SNR model. Use with caution, and closely monitor results."
),
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument(
"--allow_tf32",
@@ -545,6 +554,18 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# Check for terminal SNR in combination with SNR Gamma
if (
args.snr_gamma
and not args.force_snr_gamma
and (
hasattr(noise_scheduler.config, "rescale_betas_zero_snr") and noise_scheduler.config.rescale_betas_zero_snr
)
):
raise ValueError(
f"The selected noise scheduler for the model {args.pretrained_model_name_or_path} uses rescaled betas for zero SNR.\n"
"When this configuration is present, the parameter --snr_gamma may not be used without parameter --force_snr_gamma.\n"
"This is due to a mathematical incompatibility between our current SNR gamma implementation, and a sigma value of zero."
)
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
@@ -977,11 +998,6 @@ def main(args):
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
elif noise_scheduler.config.prediction_type == "sample":
# We set the target to latents here, but the model_pred will return the noise sample prediction.
target = model_input
# We will have to subtract the noise residual from the prediction to get the target sample.
model_pred = model_pred - noise
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
@@ -992,17 +1008,9 @@ def main(args):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -1,343 +0,0 @@
"""
This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
"""
import argparse
import os
import tempfile
import torch
from lavis.models import load_model_and_preprocess
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from diffusers import (
AutoencoderKL,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
BLIP2_CONFIG = {
"vision_config": {
"hidden_size": 1024,
"num_hidden_layers": 23,
"num_attention_heads": 16,
"image_size": 224,
"patch_size": 14,
"intermediate_size": 4096,
"hidden_act": "quick_gelu",
},
"qformer_config": {
"cross_attention_frequency": 1,
"encoder_hidden_size": 1024,
"vocab_size": 30523,
},
"num_query_tokens": 16,
}
blip2config = Blip2Config(**BLIP2_CONFIG)
def qformer_model_from_original_config():
qformer = Blip2QFormerModel(blip2config)
return qformer
def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
embeddings = {}
embeddings.update(
{
f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
f"{original_embeddings_prefix}.word_embeddings.weight"
]
}
)
embeddings.update(
{
f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
f"{original_embeddings_prefix}.position_embeddings.weight"
]
}
)
embeddings.update(
{f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
)
embeddings.update(
{f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
)
return embeddings
def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
proj_layer = {}
proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
return proj_layer
def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
attention = {}
attention.update(
{
f"{diffuser_attention_prefix}.attention.query.weight": model[
f"{original_attention_prefix}.self.query.weight"
]
}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
)
attention.update(
{
f"{diffuser_attention_prefix}.attention.value.weight": model[
f"{original_attention_prefix}.self.value.weight"
]
}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
)
attention.update(
{f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
)
attention.update(
{f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
)
attention.update(
{
f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
f"{original_attention_prefix}.output.LayerNorm.weight"
]
}
)
attention.update(
{
f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
f"{original_attention_prefix}.output.LayerNorm.bias"
]
}
)
return attention
def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
output_layers = {}
output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
output_layers.update(
{f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
)
output_layers.update(
{f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
)
return output_layers
def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
encoder = {}
for i in range(blip2config.qformer_config.num_hidden_layers):
encoder.update(
attention_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
)
)
encoder.update(
attention_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
)
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
]
}
)
encoder.update(
output_layers_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
)
)
encoder.update(
output_layers_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
)
)
return encoder
def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
visual_encoder_layer = {}
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
)
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
return visual_encoder_layer
def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
visual_encoder = {}
visual_encoder.update(
{
f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
.unsqueeze(0)
.unsqueeze(0)
}
)
visual_encoder.update(
{
f"{diffuser_prefix}.embeddings.position_embedding": model[
f"{original_prefix}.positional_embedding"
].unsqueeze(0)
}
)
visual_encoder.update(
{f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
)
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
for i in range(blip2config.vision_config.num_hidden_layers):
visual_encoder.update(
visual_encoder_layer_from_original_checkpoint(
model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
)
)
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
return visual_encoder
def qformer_original_checkpoint_to_diffusers_checkpoint(model):
qformer_checkpoint = {}
qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
qformer_checkpoint.update(
encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
)
qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
return qformer_checkpoint
def get_qformer(model):
print("loading qformer")
qformer = qformer_model_from_original_config()
qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
print("done loading qformer")
return qformer
def load_checkpoint_to_model(checkpoint, model):
with tempfile.NamedTemporaryFile(delete=False) as file:
torch.save(checkpoint, file.name)
del checkpoint
model.load_state_dict(torch.load(file.name), strict=False)
os.remove(file.name)
def save_blip_diffusion_model(model, args):
qformer = get_qformer(model)
qformer.eval()
text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
vae.eval()
text_encoder.eval()
scheduler = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
skip_prk_steps=True,
)
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
image_processor = BlipImageProcessor()
blip_diffusion = BlipDiffusionPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
image_processor=image_processor,
)
blip_diffusion.save_pretrained(args.checkpoint_path)
def main(args):
model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
save_blip_diffusion_model(model.state_dict(), args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
main(args)
@@ -35,12 +35,6 @@ if __name__ == "__main__":
type=str,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument(
"--config_files",
default=None,
type=str,
help="The YAML config file corresponding to the architecture.",
)
parser.add_argument(
"--num_in_channels",
default=None,
+2 -1
View File
@@ -256,7 +256,7 @@ setup(
package_dir={"": "src"},
packages=find_packages("src"),
include_package_data=True,
python_requires=">=3.8.0",
python_requires=">=3.7.0",
install_requires=list(install_requires),
extras_require=extras,
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
@@ -268,6 +268,7 @@ setup(
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
-8
View File
@@ -197,8 +197,6 @@ else:
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
"AudioLDMPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CycleDiffusionPipeline",
"IFImg2ImgPipeline",
@@ -368,7 +366,6 @@ else:
"FlaxDDIMScheduler",
"FlaxDDPMScheduler",
"FlaxDPMSolverMultistepScheduler",
"FlaxEulerDiscreteScheduler",
"FlaxKarrasVeScheduler",
"FlaxLMSDiscreteScheduler",
"FlaxPNDMScheduler",
@@ -396,7 +393,6 @@ else:
"FlaxStableDiffusionImg2ImgPipeline",
"FlaxStableDiffusionInpaintPipeline",
"FlaxStableDiffusionPipeline",
"FlaxStableDiffusionXLPipeline",
]
)
@@ -462,8 +458,6 @@ if TYPE_CHECKING:
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
BlipDiffusionControlNetPipeline,
BlipDiffusionPipeline,
CLIPImageProjection,
ConsistencyModelPipeline,
DanceDiffusionPipeline,
@@ -675,7 +669,6 @@ if TYPE_CHECKING:
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxEulerDiscreteScheduler,
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
@@ -694,7 +687,6 @@ if TYPE_CHECKING:
FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
FlaxStableDiffusionXLPipeline,
)
try:
+118 -205
View File
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import re
from collections import defaultdict
@@ -24,7 +23,6 @@ import requests
import safetensors
import torch
from huggingface_hub import hf_hub_download, model_info
from packaging import version
from torch import nn
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
@@ -32,15 +30,11 @@ from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
is_accelerate_available,
is_omegaconf_available,
is_peft_available,
is_transformers_available,
logging,
recurse_remove_peft_layers,
)
from .utils.import_utils import BACKENDS_MAPPING
@@ -67,21 +61,6 @@ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
# available.
# For PEFT it is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")
_required_transformers_version = version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33")
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."
class PatchedLoraProjection(nn.Module):
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
super().__init__()
@@ -1098,7 +1077,6 @@ class LoraLoaderMixin:
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
num_fused_loras = 0
use_peft_backend = USE_PEFT_BACKEND
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
"""
@@ -1290,7 +1268,6 @@ class LoraLoaderMixin:
state_dict = pretrained_model_name_or_path_or_dict
network_alphas = None
# TODO: replace it with a method from `state_dict_utils`
if all(
(
k.startswith("lora_te_")
@@ -1543,35 +1520,55 @@ class LoraLoaderMixin:
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
if cls.use_peft_backend:
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
# Convert from the old naming convention to the new naming convention.
#
# Previously, the old LoRA layers were stored on the state dict at the
# same level as the attention block i.e.
# `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
#
# This is no actual module at that point, they were monkey patched on to the
# existing module. We want to be able to load them via their actual state dict.
# They're in `PatchedLoraProjection.lora_linear_layer` now.
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_B.weight"
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
else:
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
if network_alphas is not None:
alpha_keys = [
@@ -1581,79 +1578,56 @@ class LoraLoaderMixin:
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
if cls.use_peft_backend:
from peft import LoraConfig
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
)
lora_rank = list(rank.values())[0]
# By definition, the scale should be alpha divided by rank.
# https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
alpha = lora_scale * lora_rank
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
if patch_mlp:
target_modules += ["fc1", "fc2"]
# TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
)
else:
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
unexpected_keys = load_state_dict_results.unexpected_keys
if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
)
else:
load_state_dict_results = text_encoder.load_state_dict(
text_encoder_lora_state_dict, strict=False
)
unexpected_keys = load_state_dict_results.unexpected_keys
if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
@@ -1671,27 +1645,10 @@ class LoraLoaderMixin:
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
def _remove_text_encoder_monkey_patch(self):
if self.use_peft_backend:
remove_method = recurse_remove_peft_layers
else:
remove_method = self._remove_text_encoder_monkey_patch_classmethod
if hasattr(self, "text_encoder"):
remove_method(self.text_encoder)
if self.use_peft_backend:
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
if hasattr(self, "text_encoder_2"):
remove_method(self.text_encoder_2)
if self.use_peft_backend:
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
@classmethod
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_linear_layer = None
@@ -1718,7 +1675,6 @@ class LoraLoaderMixin:
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
@@ -1922,7 +1878,7 @@ class LoraLoaderMixin:
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
# SDXL specificity.
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
if "emb" in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
@@ -1934,13 +1890,6 @@ class LoraLoaderMixin:
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
# LyCORIS specificity.
if "time.emb.proj" in diffusers_name:
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
if "conv.shortcut" in diffusers_name:
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
# General coverage.
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
@@ -2093,38 +2042,24 @@ class LoraLoaderMixin:
if fuse_unet:
self.unet.fuse_lora(lora_scale)
if self.use_peft_backend:
from peft.tuners.tuners_utils import BaseTunerLayer
def fuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
module.scale_layer(lora_scale)
module.merge()
else:
deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
if fuse_text_encoder:
if hasattr(self, "text_encoder"):
fuse_text_encoder_lora(self.text_encoder, lora_scale)
fuse_text_encoder_lora(self.text_encoder)
if hasattr(self, "text_encoder_2"):
fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
fuse_text_encoder_lora(self.text_encoder_2)
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
r"""
@@ -2146,29 +2081,18 @@ class LoraLoaderMixin:
if unfuse_unet:
self.unet.unfuse_lora()
if self.use_peft_backend:
from peft.tuners.tuner_utils import BaseTunerLayer
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._unfuse_lora()
attn_module.k_proj._unfuse_lora()
attn_module.v_proj._unfuse_lora()
attn_module.out_proj._unfuse_lora()
def unfuse_text_encoder_lora(text_encoder):
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
else:
deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._unfuse_lora()
attn_module.k_proj._unfuse_lora()
attn_module.v_proj._unfuse_lora()
attn_module.out_proj._unfuse_lora()
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._unfuse_lora()
mlp_module.fc2._unfuse_lora()
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._unfuse_lora()
mlp_module.fc2._unfuse_lora()
if unfuse_text_encoder:
if hasattr(self, "text_encoder"):
@@ -2879,16 +2803,5 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
)
def _remove_text_encoder_monkey_patch(self):
if self.use_peft_backend:
recurse_remove_peft_layers(self.text_encoder)
# TODO: @younesbelkada handle this in transformers side
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
recurse_remove_peft_layers(self.text_encoder_2)
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
else:
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
+2 -3
View File
@@ -19,7 +19,6 @@ import torch
from torch import nn
from .activations import get_activation
from .lora import LoRACompatibleLinear
def get_timestep_embedding(
@@ -167,7 +166,7 @@ class TimestepEmbedding(nn.Module):
):
super().__init__()
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -180,7 +179,7 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
if post_act_fn is None:
self.post_act = None
+11 -18
View File
@@ -25,25 +25,18 @@ from ..utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
if use_peft_backend:
from peft.tuners.lora import LoraLayer
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale
attn_module.k_proj.lora_scale = lora_scale
attn_module.v_proj.lora_scale = lora_scale
attn_module.out_proj.lora_scale = lora_scale
for module in text_encoder.modules():
if isinstance(module, LoraLayer):
module.scaling[module.active_adapter] = lora_scale
else:
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale
attn_module.k_proj.lora_scale = lora_scale
attn_module.v_proj.lora_scale = lora_scale
attn_module.out_proj.lora_scale = lora_scale
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_scale = lora_scale
mlp_module.fc2.lora_scale = lora_scale
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_scale = lora_scale
mlp_module.fc2.lora_scale = lora_scale
class LoRALinearLayer(nn.Module):
@@ -42,25 +42,9 @@ def rename_key(key):
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
# conv norm or layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
# rename attention layers
if len(pt_tuple_key) > 1:
for rename_from, rename_to in (
("to_out_0", "proj_attn"),
("to_k", "key"),
("to_v", "value"),
("to_q", "query"),
):
if pt_tuple_key[-2] == rename_from:
weight_name = pt_tuple_key[-1]
weight_name = "kernel" if weight_name == "weight" else weight_name
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
if renamed_pt_tuple_key in random_flax_state_dict:
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
return renamed_pt_tuple_key, pt_tensor.T
if (
any("norm" in str_ for str_ in pt_tuple_key)
and (pt_tuple_key[-1] == "bias")
+17 -17
View File
@@ -303,23 +303,23 @@ class FlaxModelMixin(PushToHubMixin):
"framework": "flax",
}
# Load config if we don't provide one
if config is None:
config, unused_kwargs = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
**kwargs,
)
model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)
# Load config if we don't provide a configuration
config_path = config if config is not None else pretrained_model_name_or_path
model, model_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
# model args
dtype=dtype,
**kwargs,
)
# Load model
pretrained_path_with_subfolder = (
+3 -6
View File
@@ -52,7 +52,6 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1
def setup(self):
resnets = []
@@ -73,7 +72,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
in_channels=self.out_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=self.transformer_layers_per_block,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -193,7 +192,6 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1
def setup(self):
resnets = []
@@ -215,7 +213,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
in_channels=self.out_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=self.transformer_layers_per_block,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -333,7 +331,6 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1
def setup(self):
# there is always at least one resnet
@@ -353,7 +350,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
in_channels=self.in_channels,
n_heads=self.num_attention_heads,
d_head=self.in_channels // self.num_attention_heads,
depth=self.transformer_layers_per_block,
depth=1,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
@@ -883,6 +883,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
+2 -65
View File
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
@@ -116,11 +116,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0
use_memory_efficient_attention: bool = False
transformer_layers_per_block: Union[int, Tuple[int]] = 1
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
addition_embed_type_num_heads: int = 64
projection_class_embeddings_input_dim: Optional[int] = None
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
@@ -132,17 +127,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
added_cond_kwargs = None
if self.addition_embed_type == "text_time":
# TODO: how to get this from the config? It's no longer cross_attention_dim
text_embeds_dim = 1280
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
added_cond_kwargs = {
"text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
"time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
}
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
def setup(self):
block_out_channels = self.block_out_channels
@@ -183,24 +168,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
# transformer layers per block
transformer_layers_per_block = self.transformer_layers_per_block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
# addition embed types
if self.addition_embed_type is None:
self.add_embedding = None
elif self.addition_embed_type == "text_time":
if self.addition_time_embed_dim is None:
raise ValueError(
f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
)
self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
else:
raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")
# down
down_blocks = []
output_channel = block_out_channels[0]
@@ -215,7 +182,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
num_attention_heads=num_attention_heads[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
@@ -241,7 +207,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
in_channels=block_out_channels[-1],
dropout=self.dropout,
num_attention_heads=num_attention_heads[-1],
transformer_layers_per_block=transformer_layers_per_block[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
@@ -253,7 +218,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
reversed_num_attention_heads = list(reversed(num_attention_heads))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
@@ -267,7 +231,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
num_attention_heads=reversed_num_attention_heads[i],
add_upsample=not is_final_block,
dropout=self.dropout,
@@ -306,7 +269,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
sample,
timesteps,
encoder_hidden_states,
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
down_block_additional_residuals=None,
mid_block_additional_residual=None,
return_dict: bool = True,
@@ -338,31 +300,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)
# additional embeddings
aug_emb = None
if self.addition_embed_type == "text_time":
if added_cond_kwargs is None:
raise ValueError(
f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if text_embeds is None:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
if time_ids is None:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
# compute time embeds
time_embeds = self.add_time_proj(jnp.ravel(time_ids)) # (1, 6) => (6,) => (6, 256)
time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
aug_emb = self.add_embedding(add_embeds)
t_emb = t_emb + aug_emb if aug_emb is not None else t_emb
# 2. pre-process
sample = jnp.transpose(sample, (0, 2, 3, 1))
sample = self.conv_in(sample)
+7 -22
View File
@@ -16,7 +16,7 @@ from ..utils import (
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {"stable_diffusion": [], "stable_diffusion_xl": [], "latent_diffusion": [], "controlnet": []}
_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}
try:
if not is_torch_available():
@@ -67,10 +67,8 @@ else:
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPipeline",
@@ -142,14 +140,12 @@ else:
]
)
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
_import_structure["stable_diffusion_xl"].extend(
[
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPipeline",
]
)
_import_structure["stable_diffusion_xl"] = [
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPipeline",
]
_import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
_import_structure["text_to_video_synthesis"] = [
"TextToVideoSDPipeline",
@@ -202,7 +198,6 @@ else:
"StableDiffusionOnnxPipeline",
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
@@ -237,11 +232,6 @@ else:
"FlaxStableDiffusionPipeline",
]
)
_import_structure["stable_diffusion_xl"].extend(
[
"FlaxStableDiffusionXLPipeline",
]
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
@@ -291,9 +281,7 @@ if TYPE_CHECKING:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
from .blip_diffusion import BlipDiffusionPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
@@ -445,9 +433,6 @@ if TYPE_CHECKING:
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
)
from .stable_diffusion_xl import (
FlaxStableDiffusionXLPipeline,
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
@@ -303,7 +303,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -301,7 +301,7 @@ class AltDiffusionImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1,20 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
from .blip_image_processing import BlipImageProcessor
from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel
from .pipeline_blip_diffusion import BlipDiffusionPipeline
@@ -1,318 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for BLIP."""
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType, is_vision_available, logging
from diffusers.utils import numpy_to_pil
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop
# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor
class BlipImageProcessor(BaseImageProcessor):
r"""
Constructs a BLIP image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
do_center_crop: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size, default_to_square=True)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self.do_center_crop = do_center_crop
# Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
do_center_crop: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
do_convert_rgb: bool = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Controls the size of the image after `resize`. The shortest edge of the image is resized to
`size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to normalize the image by if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize:
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_rescale:
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize:
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
if do_center_crop:
images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
return encoded_outputs
# Follows diffusers.VaeImageProcessor.postprocess
def postprocess(self, sample: torch.FloatTensor, output_type: str = "pil"):
if output_type not in ["pt", "np", "pil"]:
raise ValueError(
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
)
# Equivalent to diffusers.VaeImageProcessor.denormalize
sample = (sample / 2 + 0.5).clamp(0, 1)
if output_type == "pt":
return sample
# Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "np":
return sample
# Output_type must be 'pil'
sample = numpy_to_pil(sample)
return sample
@@ -1,642 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import BertTokenizer
from transformers.activations import QuickGELUActivation as QuickGELU
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig
from transformers.models.blip_2.modeling_blip_2 import (
Blip2Encoder,
Blip2PreTrainedModel,
Blip2QFormerAttention,
Blip2QFormerIntermediate,
Blip2QFormerOutput,
)
from transformers.pytorch_utils import apply_chunking_to_forward
from transformers.utils import (
logging,
replace_return_docstrings,
)
logger = logging.get_logger(__name__)
# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py.
# But it doesn't support getting multimodal embeddings. So, this module can be
# replaced with a future `transformers` version supports that.
class Blip2TextEmbeddings(nn.Module):
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
def forward(
self,
input_ids=None,
position_ids=None,
query_embeds=None,
past_key_values_length=0,
):
if input_ids is not None:
seq_length = input_ids.size()[1]
else:
seq_length = 0
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
if input_ids is not None:
embeddings = self.word_embeddings(input_ids)
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
if query_embeds is not None:
batch_size = embeddings.shape[0]
# repeat the query embeddings for batch size
query_embeds = query_embeds.repeat(batch_size, 1, 1)
embeddings = torch.cat((query_embeds, embeddings), dim=1)
else:
embeddings = query_embeds
embeddings = embeddings.to(query_embeds.dtype)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
class Blip2VisionEmbeddings(nn.Module):
def __init__(self, config: Blip2VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
return embeddings
# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings
class Blip2QFormerEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList(
[Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
query_length=0,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, query_length)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if layer_module.has_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
# The layers making up the Qformer encoder
class Blip2QFormerLayer(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = Blip2QFormerAttention(config)
self.layer_idx = layer_idx
if layer_idx % config.cross_attention_frequency == 0:
self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True)
self.has_cross_attention = True
else:
self.has_cross_attention = False
self.intermediate = Blip2QFormerIntermediate(config)
self.intermediate_query = Blip2QFormerIntermediate(config)
self.output_query = Blip2QFormerOutput(config)
self.output = Blip2QFormerOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
query_length=0,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
if query_length > 0:
query_attention_output = attention_output[:, :query_length, :]
if self.has_cross_attention:
if encoder_hidden_states is None:
raise ValueError("encoder_hidden_states must be given for cross-attention layers")
cross_attention_outputs = self.crossattention(
query_attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
query_attention_output = cross_attention_outputs[0]
# add cross attentions if we output attention weights
outputs = outputs + cross_attention_outputs[1:-1]
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk_query,
self.chunk_size_feed_forward,
self.seq_len_dim,
query_attention_output,
)
if attention_output.shape[1] > query_length:
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output[:, query_length:, :],
)
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
else:
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def feed_forward_chunk_query(self, attention_output):
intermediate_output = self.intermediate_query(attention_output)
layer_output = self.output_query(intermediate_output, attention_output)
return layer_output
# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder
class ProjLayer(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
super().__init__()
# Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm
self.dense1 = nn.Linear(in_dim, hidden_dim)
self.act_fn = QuickGELU()
self.dense2 = nn.Linear(hidden_dim, out_dim)
self.dropout = nn.Dropout(drop_p)
self.LayerNorm = nn.LayerNorm(out_dim, eps=eps)
def forward(self, x):
x_in = x
x = self.LayerNorm(x)
x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in
return x
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2
class Blip2VisionModel(Blip2PreTrainedModel):
main_input_name = "pixel_values"
config_class = Blip2VisionConfig
def __init__(self, config: Blip2VisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = Blip2VisionEmbeddings(config)
self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = Blip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.post_init()
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layernorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.embeddings
# Qformer model, used to get multimodal embeddings from the text and image inputs
class Blip2QFormerModel(Blip2PreTrainedModel):
"""
Querying Transformer (Q-Former), used in BLIP-2.
"""
def __init__(self, config: Blip2Config):
super().__init__(config)
self.config = config
self.embeddings = Blip2TextEmbeddings(config.qformer_config)
self.visual_encoder = Blip2VisionModel(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
if not hasattr(config, "tokenizer") or config.tokenizer is None:
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
else:
self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right")
self.tokenizer.add_special_tokens({"bos_token": "[DEC]"})
self.proj_layer = ProjLayer(
in_dim=config.qformer_config.hidden_size,
out_dim=config.qformer_config.hidden_size,
hidden_dim=config.qformer_config.hidden_size * 4,
drop_p=0.1,
eps=1e-12,
)
self.encoder = Blip2QFormerEncoder(config.qformer_config)
self.post_init()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def get_extended_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple[int],
device: torch.device,
has_query: bool = False,
) -> torch.Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
device (`torch.device`):
The device of the input to the model.
Returns:
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def forward(
self,
text_input=None,
image_input=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
`(batch_size, sequence_length)`.
use_cache (`bool`, `optional`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
text = self.tokenizer(text_input, return_tensors="pt", padding=True)
text = text.to(self.device)
input_ids = text.input_ids
batch_size = input_ids.shape[0]
query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device)
attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# past_key_values_length
past_key_values_length = (
past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
)
query_length = self.query_tokens.shape[1]
embedding_output = self.embeddings(
input_ids=input_ids,
query_embeds=self.query_tokens,
past_key_values_length=past_key_values_length,
)
# embedding_output = self.layernorm(query_embeds)
# embedding_output = self.dropout(embedding_output)
input_shape = embedding_output.size()[:-1]
batch_size, seq_length = input_shape
device = embedding_output.device
image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state
# image_embeds_frozen = torch.ones_like(image_embeds_frozen)
encoder_hidden_states = image_embeds_frozen
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, list):
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
else:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if isinstance(encoder_attention_mask, list):
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
elif encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
query_length=query_length,
)
sequence_output = encoder_outputs[0]
pooled_output = sequence_output[:, 0, :]
if not return_dict:
return self.proj_layer(sequence_output[:, :query_length, :])
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
@@ -1,212 +0,0 @@
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import CLIPPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.configuration_clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import (
CLIPEncoder,
_expand_mask,
)
# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer
# They pass through the clip model, along with the text embeddings, and interact with them using self attention
class ContextCLIPTextModel(CLIPPreTrainedModel):
config_class = CLIPTextConfig
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
self.text_model = ContextCLIPTextTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
ctx_embeddings: torch.Tensor = None,
ctx_begin_pos: list = None,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
return self.text_model(
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=ctx_begin_pos,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class ContextCLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = ContextCLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim)
def forward(
self,
ctx_embeddings: torch.Tensor,
ctx_begin_pos: list,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=ctx_begin_pos,
)
bsz, seq_len = input_shape
if ctx_embeddings is not None:
seq_len += ctx_embeddings.size(1)
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=input_ids.device),
input_ids.to(torch.int).argmax(dim=-1),
]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class ContextCLIPTextEmbeddings(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(
self,
ctx_embeddings: torch.Tensor,
ctx_begin_pos: list,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
if ctx_embeddings is None:
ctx_len = 0
else:
ctx_len = ctx_embeddings.shape[1]
seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
# for each input embeddings, add the ctx embeddings at the correct position
input_embeds_ctx = []
bsz = inputs_embeds.shape[0]
if ctx_embeddings is not None:
for i in range(bsz):
cbp = ctx_begin_pos[i]
prefix = inputs_embeds[i, :cbp]
# remove the special token embedding
suffix = inputs_embeds[i, cbp:]
input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0))
inputs_embeds = torch.stack(input_embeds_ctx, dim=0)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
@@ -1,339 +0,0 @@
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import PIL
import torch
from transformers import CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .blip_image_processing import BlipImageProcessor
from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers.pipelines import BlipDiffusionPipeline
>>> from diffusers.utils import load_image
>>> import torch
>>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained(
... "Salesforce/blipdiffusion", torch_dtype=torch.float16
... ).to("cuda")
>>> cond_subject = "dog"
>>> tgt_subject = "dog"
>>> text_prompt_input = "swimming underwater"
>>> cond_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
... )
>>> guidance_scale = 7.5
>>> num_inference_steps = 25
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
>>> output = blip_diffusion_pipe(
... text_prompt_input,
... cond_image,
... cond_subject,
... tgt_subject,
... guidance_scale=guidance_scale,
... num_inference_steps=num_inference_steps,
... neg_prompt=negative_prompt,
... height=512,
... width=512,
... ).images
>>> output[0].save("image.png")
```
"""
class BlipDiffusionPipeline(DiffusionPipeline):
"""
Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer ([`CLIPTokenizer`]):
Tokenizer for the text encoder
text_encoder ([`ContextCLIPTextModel`]):
Text encoder to encode the text prompt
vae ([`AutoencoderKL`]):
VAE model to map the latents to the image
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
scheduler ([`PNDMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
qformer ([`Blip2QFormerModel`]):
QFormer model to get multi-modal embeddings from the text and image.
image_processor ([`BlipImageProcessor`]):
Image Processor to preprocess and postprocess the image.
ctx_begin_pos (int, `optional`, defaults to 2):
Position of the context token in the text encoder.
"""
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: ContextCLIPTextModel,
vae: AutoencoderKL,
unet: UNet2DConditionModel,
scheduler: PNDMScheduler,
qformer: Blip2QFormerModel,
image_processor: BlipImageProcessor,
ctx_begin_pos: int = 2,
mean: List[float] = None,
std: List[float] = None,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
image_processor=image_processor,
)
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
# from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
prompt = f"a {tgt_subject} {prompt.strip()}"
# a trick to amplify the prompt
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
return rv
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt):
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
tokenized_prompt = self.tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
text_embeddings = self.text_encoder(
input_ids=tokenized_prompt.input_ids,
ctx_embeddings=query_embeds,
ctx_begin_pos=ctx_begin_pos,
)[0]
return text_embeddings
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: List[str],
reference_image: PIL.Image.Image,
source_subject_category: List[str],
target_subject_category: List[str],
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 7.5,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
neg_prompt: Optional[str] = "",
prompt_strength: float = 1.0,
prompt_reps: int = 20,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`List[str]`):
The prompt or prompts to guide the image generation.
reference_image (`PIL.Image.Image`):
The reference image to condition the generation on.
source_subject_category (`List[str]`):
The source subject category.
target_subject_category (`List[str]`):
The target subject category.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
height (`int`, *optional*, defaults to 512):
The height of the generated image.
width (`int`, *optional*, defaults to 512):
The width of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
neg_prompt (`str`, *optional*, defaults to ""):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_strength (`float`, *optional*, defaults to 1.0):
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
to amplify the prompt.
prompt_reps (`int`, *optional*, defaults to 20):
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(source_subject_category, str):
source_subject_category = [source_subject_category]
if isinstance(target_subject_category, str):
target_subject_category = [target_subject_category]
batch_size = len(prompt)
prompt = self._build_prompt(
prompts=prompt,
tgt_subjects=target_subject_category,
prompt_strength=prompt_strength,
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
uncond_input = self.tokenizer(
[neg_prompt] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
latents = self.prepare_latents(
batch_size=batch_size,
num_channels=self.unet.config.in_channels,
height=height // scale_down_factor,
width=width // scale_down_factor,
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
)
# set timesteps
extra_set_kwargs = {}
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
do_classifier_free_guidance = guidance_scale > 1.0
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
noise_pred = self.unet(
latent_model_input,
timestep=t,
encoder_hidden_states=text_embeddings,
down_block_additional_residuals=None,
mid_block_additional_residual=None,
)["sample"]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
t,
latents,
)["prev_sample"]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
+77 -79
View File
@@ -1,79 +1,77 @@
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
if TYPE_CHECKING:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .multicontrolnet import MultiControlNetModel
from .pipeline_controlnet import StableDiffusionControlNetPipeline
from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
if TYPE_CHECKING:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .multicontrolnet import MultiControlNetModel
from .pipeline_controlnet import StableDiffusionControlNetPipeline
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -291,7 +291,7 @@ class StableDiffusionControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1,405 +0,0 @@
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import PIL
import torch
from transformers import CLIPTokenizer
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
>>> from diffusers.utils import load_image
>>> from controlnet_aux import CannyDetector
>>> import torch
>>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
... ).to("cuda")
>>> style_subject = "flower"
>>> tgt_subject = "teapot"
>>> text_prompt = "on a marble table"
>>> cldm_cond_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
... ).resize((512, 512))
>>> canny = CannyDetector()
>>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
>>> style_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
... )
>>> guidance_scale = 7.5
>>> num_inference_steps = 50
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
>>> output = blip_diffusion_pipe(
... text_prompt,
... style_image,
... cldm_cond_image,
... style_subject,
... tgt_subject,
... guidance_scale=guidance_scale,
... num_inference_steps=num_inference_steps,
... neg_prompt=negative_prompt,
... height=512,
... width=512,
... ).images
>>> output[0].save("image.png")
```
"""
class BlipDiffusionControlNetPipeline(DiffusionPipeline):
"""
Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer ([`CLIPTokenizer`]):
Tokenizer for the text encoder
text_encoder ([`ContextCLIPTextModel`]):
Text encoder to encode the text prompt
vae ([`AutoencoderKL`]):
VAE model to map the latents to the image
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
scheduler ([`PNDMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
qformer ([`Blip2QFormerModel`]):
QFormer model to get multi-modal embeddings from the text and image.
controlnet ([`ControlNetModel`]):
ControlNet model to get the conditioning image embedding.
image_processor ([`BlipImageProcessor`]):
Image Processor to preprocess and postprocess the image.
ctx_begin_pos (int, `optional`, defaults to 2):
Position of the context token in the text encoder.
"""
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: ContextCLIPTextModel,
vae: AutoencoderKL,
unet: UNet2DConditionModel,
scheduler: PNDMScheduler,
qformer: Blip2QFormerModel,
controlnet: ControlNetModel,
image_processor: BlipImageProcessor,
ctx_begin_pos: int = 2,
mean: List[float] = None,
std: List[float] = None,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
controlnet=controlnet,
image_processor=image_processor,
)
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
# from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
prompt = f"a {tgt_subject} {prompt.strip()}"
# a trick to amplify the prompt
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
return rv
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt):
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
tokenized_prompt = self.tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
text_embeddings = self.text_encoder(
input_ids=tokenized_prompt.input_ids,
ctx_embeddings=query_embeds,
ctx_begin_pos=ctx_begin_pos,
)[0]
return text_embeddings
# Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
def prepare_control_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
):
image = self.image_processor.preprocess(
image,
size={"width": width, "height": height},
do_rescale=True,
do_center_crop=False,
do_normalize=False,
return_tensors="pt",
)["pixel_values"].to(self.device)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
return image
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: List[str],
reference_image: PIL.Image.Image,
condtioning_image: PIL.Image.Image,
source_subject_category: List[str],
target_subject_category: List[str],
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 7.5,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
neg_prompt: Optional[str] = "",
prompt_strength: float = 1.0,
prompt_reps: int = 20,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`List[str]`):
The prompt or prompts to guide the image generation.
reference_image (`PIL.Image.Image`):
The reference image to condition the generation on.
condtioning_image (`PIL.Image.Image`):
The conditioning canny edge image to condition the generation on.
source_subject_category (`List[str]`):
The source subject category.
target_subject_category (`List[str]`):
The target subject category.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
height (`int`, *optional*, defaults to 512):
The height of the generated image.
width (`int`, *optional*, defaults to 512):
The width of the generated image.
seed (`int`, *optional*, defaults to 42):
The seed to use for random generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
neg_prompt (`str`, *optional*, defaults to ""):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_strength (`float`, *optional*, defaults to 1.0):
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
to amplify the prompt.
prompt_reps (`int`, *optional*, defaults to 20):
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(source_subject_category, str):
source_subject_category = [source_subject_category]
if isinstance(target_subject_category, str):
target_subject_category = [target_subject_category]
batch_size = len(prompt)
prompt = self._build_prompt(
prompts=prompt,
tgt_subjects=target_subject_category,
prompt_strength=prompt_strength,
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
# 3. unconditional embedding
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
uncond_input = self.tokenizer(
[neg_prompt] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
latents = self.prepare_latents(
batch_size=batch_size,
num_channels=self.unet.config.in_channels,
height=height // scale_down_factor,
width=width // scale_down_factor,
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
)
# set timesteps
extra_set_kwargs = {}
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
cond_image = self.prepare_control_image(
image=condtioning_image,
width=width,
height=height,
batch_size=batch_size,
num_images_per_prompt=1,
device=self.device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
do_classifier_free_guidance = guidance_scale > 1.0
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=cond_image,
return_dict=False,
)
noise_pred = self.unet(
latent_model_input,
timestep=t,
encoder_hidden_states=text_embeddings,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)["sample"]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
t,
latents,
)["prev_sample"]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
@@ -315,7 +315,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -442,7 +442,7 @@ class StableDiffusionControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -315,8 +315,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -288,8 +288,8 @@ class StableDiffusionXLControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -326,8 +326,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
+2 -21
View File
@@ -394,29 +394,10 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
# define init kwargs
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
# remove `null` components
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
# Throw nice warnings / errors for fast accelerate loading
if len(unused_kwargs) > 0:
logger.warning(
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
)
init_kwargs = {}
# inference_params
params = {}
@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -301,7 +301,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -332,7 +332,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -481,7 +481,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -278,7 +278,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -309,7 +309,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -302,7 +302,7 @@ class StableDiffusionImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -375,7 +375,7 @@ class StableDiffusionInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -297,7 +297,7 @@ class StableDiffusionInpaintPipelineLegacy(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -211,7 +211,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -609,9 +609,6 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
sampler_kwargs["noise_sampler"] = noise_sampler
if "generator" in inspect.signature(self.sampler).parameters:
sampler_kwargs["generator"] = generator
latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
if not output_type == "latent":
@@ -272,7 +272,7 @@ class StableDiffusionLDM3DPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -244,7 +244,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -221,7 +221,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -256,7 +256,7 @@ class StableDiffusionParadigmsPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -446,7 +446,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -244,7 +244,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -240,7 +240,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -346,7 +346,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -296,7 +296,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -4,18 +4,14 @@ from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["StableDiffusionXLPipelineOutput"]}
if is_transformers_available() and is_flax_available():
_import_structure["pipeline_output"].extend(["FlaxStableDiffusionXLPipelineOutput"])
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
@@ -29,12 +25,6 @@ else:
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
if is_transformers_available() and is_flax_available():
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
_additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState})
_import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"]
if TYPE_CHECKING:
try:
@@ -48,17 +38,6 @@ if TYPE_CHECKING:
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_objects import *
else:
from .pipeline_flax_stable_diffusion_xl import (
FlaxStableDiffusionXLPipeline,
)
from .pipeline_output import FlaxStableDiffusionXLPipelineOutput
else:
import sys
@@ -71,5 +50,3 @@ else:
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)
@@ -1,306 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Dict, List, Optional, Union
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformers import CLIPTokenizer, FlaxCLIPTextModel
from diffusers.utils import logging
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...schedulers import (
FlaxDDIMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionXLPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False
class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
def __init__(
self,
text_encoder: FlaxCLIPTextModel,
text_encoder_2: FlaxCLIPTextModel,
vae: FlaxAutoencoderKL,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: FlaxUNet2DConditionModel,
scheduler: Union[
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
self.dtype = dtype
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# Assume we have the two encoders
inputs = []
for tokenizer in [self.tokenizer, self.tokenizer_2]:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
inputs.append(text_inputs.input_ids)
inputs = jnp.stack(inputs, axis=1)
return inputs
def __call__(
self,
prompt_ids: jax.Array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
num_inference_steps: int = 50,
guidance_scale: Union[float, jax.Array] = 7.5,
height: Optional[int] = None,
width: Optional[int] = None,
latents: jnp.array = None,
neg_prompt_ids: jnp.array = None,
return_dict: bool = True,
output_type: str = None,
jit: bool = False,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(guidance_scale, float) and jit:
# Convert to a tensor so each device gets a copy.
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
guidance_scale = guidance_scale[:, None]
return_latents = output_type == "latent"
if jit:
images = _p_generate(
self,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
neg_prompt_ids,
return_latents,
)
else:
images = self._generate(
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
neg_prompt_ids,
return_latents,
)
if not return_dict:
return (images,)
return FlaxStableDiffusionXLPipelineOutput(images=images)
def get_embeddings(self, prompt_ids: jnp.array, params):
# We assume we have the two encoders
# bs, encoder_input, seq_length
te_1_inputs = prompt_ids[:, 0, :]
te_2_inputs = prompt_ids[:, 1, :]
prompt_embeds = self.text_encoder(te_1_inputs, params=params["text_encoder"], output_hidden_states=True)
prompt_embeds = prompt_embeds["hidden_states"][-2]
prompt_embeds_2_out = self.text_encoder_2(
te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True
)
prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2]
text_embeds = prompt_embeds_2_out["text_embeds"]
prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
return prompt_embeds, text_embeds
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, bs, dtype):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = jnp.array([add_time_ids] * bs, dtype=dtype)
return add_time_ids
def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
neg_prompt_ids: Optional[jnp.array] = None,
return_latents=False,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# Encode input prompt
prompt_embeds, pooled_embeds = self.get_embeddings(prompt_ids, params)
# Get unconditional embeddings
batch_size = prompt_embeds.shape[0]
if neg_prompt_ids is None:
neg_prompt_ids = self.prepare_inputs([""] * batch_size)
neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
add_time_ids = self._get_add_time_ids(
(height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
)
prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0) # (2, 77, 2048)
add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
# Ensure model output will be `float32` before going into the scheduler
guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
# Create random latents
latents_shape = (
batch_size,
self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * params["scheduler"].init_noise_sigma
# Prepare scheduler state
scheduler_state = self.scheduler.set_timesteps(
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# Denoising loop
def loop_body(step, args):
latents, scheduler_state = args
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input = jnp.concatenate([latents] * 2)
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])
latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
# predict the noise residual
noise_pred = self.unet.apply(
{"params": params["unet"]},
jnp.array(latents_input),
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=prompt_embeds,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents, scheduler_state
if DEBUG:
# run with python for loop
for i in range(num_inference_steps):
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
else:
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
if return_latents:
return latents
# Decode latents
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image
# Static argnums are pipe, num_inference_steps, height, width, return_latents. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial(
jax.pmap,
in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0, None),
static_broadcasted_argnums=(0, 4, 5, 6, 10),
)
def _p_generate(
pipe,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
neg_prompt_ids,
return_latents,
):
return pipe._generate(
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
neg_prompt_ids,
return_latents,
)
@@ -4,11 +4,7 @@ from typing import List, Union
import numpy as np
import PIL
from ...utils import (
BaseOutput,
is_flax_available,
is_transformers_available,
)
from ...utils import BaseOutput
@dataclass
@@ -23,19 +19,3 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
"""
images: Union[List[PIL.Image.Image], np.ndarray]
if is_transformers_available() and is_flax_available():
import flax
@flax.struct.dataclass
class FlaxStableDiffusionXLPipelineOutput(BaseOutput):
"""
Output class for Flax Stable Diffusion XL pipelines.
Args:
images (`np.ndarray`)
Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
"""
images: np.ndarray
@@ -264,8 +264,8 @@ class StableDiffusionXLPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -271,8 +271,8 @@ class StableDiffusionXLImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -722,16 +722,16 @@ class StableDiffusionXLImg2ImgPipeline(
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -420,8 +420,8 @@ class StableDiffusionXLInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -272,8 +272,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -296,7 +296,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -288,8 +288,8 @@ class StableDiffusionXLAdapterPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -787,16 +787,8 @@ class StableDiffusionXLAdapterPipeline(
height, width = self._default_height_width(height, width, image)
device = self._execution_device
if isinstance(self.adapter, MultiAdapter):
adapter_input = []
adapter_input = _preprocess_adapter_image(image, height, width).to(device)
for one_image in image:
one_image = _preprocess_adapter_image(one_image, height, width)
one_image = one_image.to(device=device, dtype=self.adapter.dtype)
adapter_input.append(one_image)
else:
adapter_input = _preprocess_adapter_image(image, height, width)
adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)
original_size = original_size or (height, width)
target_size = target_size or (height, width)
@@ -873,14 +865,10 @@ class StableDiffusionXLAdapterPipeline(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare added time ids & embeddings & adapter features
if isinstance(self.adapter, MultiAdapter):
adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
for k, v in enumerate(adapter_state):
adapter_state[k] = v
else:
adapter_state = self.adapter(adapter_input)
for k, v in enumerate(adapter_state):
adapter_state[k] = v * adapter_conditioning_scale
adapter_input = adapter_input.type(latents.dtype)
adapter_state = self.adapter(adapter_input)
for k, v in enumerate(adapter_state):
adapter_state[k] = v * adapter_conditioning_scale
if num_images_per_prompt > 1:
for k, v in enumerate(adapter_state):
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
@@ -228,7 +228,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -290,7 +290,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1094,6 +1094,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
-1
View File
@@ -76,7 +76,6 @@ else:
_import_structure["scheduling_ddim_flax"] = ["FlaxDDIMScheduler"]
_import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
_import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
_import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
_import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
_import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
@@ -22,7 +22,6 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
@@ -187,14 +186,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
@@ -234,17 +225,18 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
timesteps = np.flip(timesteps).copy().astype(np.int64)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
self.timesteps = torch.from_numpy(timesteps).to(device)
self.num_inference_steps = len(timesteps)
@@ -253,9 +245,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order
self.lower_order_nums = 0
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
@@ -291,57 +280,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def convert_model_output(
self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
"""
Convert the model output to the corresponding type the DEIS algorithm needs.
@@ -358,26 +298,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
@@ -389,6 +316,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
x0_pred = self._threshold_sample(x0_pred)
if self.config.algorithm_type == "deis":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
return (sample - alpha_t * x0_pred) / sigma_t
else:
raise NotImplementedError("only support log-rho multistep deis now")
@@ -396,9 +324,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def deis_first_order_update(
self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
timestep: int,
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the first-order DEIS (equivalent to DDIM).
@@ -417,33 +345,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep]
h = lambda_t - lambda_s
if self.config.algorithm_type == "deis":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
@@ -454,9 +358,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_deis_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
*args,
sample: torch.FloatTensor = None,
**kwargs,
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the second-order multistep DEIS.
@@ -464,6 +368,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
@@ -471,38 +379,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1]
sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1]
rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
@@ -523,9 +403,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_deis_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
*args,
sample: torch.FloatTensor = None,
**kwargs,
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the third-order multistep DEIS.
@@ -533,6 +413,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by diffusion process.
@@ -540,47 +424,15 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing`sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2]
sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2]
rho_t, rho_s0, rho_s1, rho_s2 = (
sigma_t / alpha_t,
sigma_s0 / alpha_s0,
sigma_s1 / alpha_s1,
sigma_s2 / alpha_s2,
simga_s2 / alpha_s2,
)
if self.config.algorithm_type == "deis":
@@ -608,25 +460,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError("only support log-rho multistep deis now")
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
def step(
self,
model_output: torch.FloatTensor,
@@ -659,34 +492,42 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
lower_order_final = (
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)
model_output = self.convert_model_output(model_output, sample=sample)
model_output = self.convert_model_output(model_output, timestep, sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.deis_first_order_update(model_output, sample=sample)
prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample)
timestep_list = [self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_deis_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
else:
prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample)
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_deis_third_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
@@ -707,31 +548,28 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
@@ -21,7 +21,6 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
@@ -204,14 +203,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
@@ -251,20 +242,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
if self.config.use_karras_sigmas:
sigmas = np.flip(sigmas).copy()
log_sigmas = np.log(sigmas)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
timesteps = np.flip(timesteps).copy().astype(np.int64)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
self.timesteps = torch.from_numpy(timesteps).to(device)
self.num_inference_steps = len(timesteps)
@@ -273,9 +264,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order
self.lower_order_nums = 0
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
@@ -335,12 +323,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape)
return t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
@@ -356,11 +338,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas
def convert_model_output(
self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
@@ -377,6 +355,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
@@ -384,18 +364,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
@@ -403,14 +371,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
@@ -432,12 +398,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
else:
epsilon = model_output
elif self.config.prediction_type == "sample":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
else:
raise ValueError(
@@ -446,8 +410,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
)
if self.config.thresholding:
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t
@@ -457,10 +420,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
timestep: int,
prev_timestep: int,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.FloatTensor:
"""
One step for the first-order DPMSolver (equivalent to DDIM).
@@ -468,6 +431,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
@@ -475,33 +442,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
@@ -526,10 +469,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
*args,
sample: torch.FloatTensor = None,
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.FloatTensor:
"""
One step for the second-order multistep DPMSolver.
@@ -537,6 +480,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
@@ -544,43 +491,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
@@ -649,9 +564,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
*args,
sample: torch.FloatTensor = None,
**kwargs,
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the third-order multistep DPMSolver.
@@ -659,6 +574,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by diffusion process.
@@ -666,47 +585,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing`sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
self.lambda_t[t],
self.lambda_t[s0],
self.lambda_t[s1],
self.lambda_t[s2],
)
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
@@ -731,25 +619,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
)
return x_t
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
def step(
self,
model_output: torch.FloatTensor,
@@ -785,17 +654,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
lower_order_final = (
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)
model_output = self.convert_model_output(model_output, sample=sample)
model_output = self.convert_model_output(model_output, timestep, sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
@@ -808,18 +682,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
noise = None
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
prev_sample = self.dpm_solver_first_order_update(
model_output, timestep, prev_timestep, sample, noise=noise
)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
timestep_list = [self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
)
else:
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
@@ -840,30 +719,28 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
@@ -21,7 +21,6 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
@@ -204,16 +203,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.use_karras_sigmas = use_karras_sigmas
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -253,19 +244,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = timesteps.copy().astype(np.int64)
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_max = (
(1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
@@ -274,7 +257,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.num_inference_steps = len(timesteps)
@@ -283,9 +266,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order
self.lower_order_nums = 0
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
@@ -345,13 +325,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
@@ -368,11 +341,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
def convert_model_output(
self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
@@ -389,6 +358,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
@@ -396,18 +367,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
@@ -415,14 +374,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
@@ -444,12 +401,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
else:
epsilon = model_output
elif self.config.prediction_type == "sample":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
else:
raise ValueError(
@@ -458,22 +413,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
)
if self.config.thresholding:
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t
return epsilon
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
timestep: int,
prev_timestep: int,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.FloatTensor:
"""
One step for the first-order DPMSolver (equivalent to DDIM).
@@ -481,6 +434,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
@@ -488,62 +445,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
x_t = (
(alpha_t / alpha_s) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
elif "sde" in self.config.algorithm_type:
raise NotImplementedError(
f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
)
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
*args,
sample: torch.FloatTensor = None,
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.FloatTensor:
"""
One step for the second-order multistep DPMSolver.
@@ -551,6 +473,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
@@ -558,43 +484,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
@@ -626,47 +520,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * (torch.exp(h) - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
elif "sde" in self.config.algorithm_type:
raise NotImplementedError(
f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
)
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
*args,
sample: torch.FloatTensor = None,
**kwargs,
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the third-order multistep DPMSolver.
@@ -674,6 +540,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by diffusion process.
@@ -681,47 +551,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing`sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
self.lambda_t[t],
self.lambda_t[s0],
self.lambda_t[s1],
self.lambda_t[s2],
)
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
@@ -746,27 +585,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
)
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
def step(
self,
model_output: torch.FloatTensor,
@@ -786,8 +604,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
@@ -802,17 +618,24 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = (
self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
)
lower_order_final = (
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)
model_output = self.convert_model_output(model_output, sample=sample)
model_output = self.convert_model_output(model_output, timestep, sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
@@ -825,18 +648,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
noise = None
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
prev_sample = self.dpm_solver_first_order_update(
model_output, timestep, prev_timestep, sample, noise=noise
)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
timestep_list = [self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
)
else:
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
@@ -858,31 +686,28 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
@@ -21,7 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate, logging
from ..utils import logging
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
@@ -197,7 +197,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.sample = None
self.order_list = self.get_order_list(num_train_timesteps)
self._step_index = None
def get_order_list(self, num_inference_steps: int) -> List[int]:
"""
@@ -233,13 +232,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
orders = [1] * steps
return orders
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -264,18 +256,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
timesteps = np.flip(timesteps).copy().astype(np.int64)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.model_outputs = [None] * self.config.solver_order
self.sample = None
@@ -287,9 +274,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.order_list = self.get_order_list(num_inference_steps)
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
@@ -349,13 +333,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
@@ -371,11 +348,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
return sigmas
def convert_model_output(
self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
@@ -392,6 +365,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
@@ -399,32 +374,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
@@ -444,13 +405,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
model_output = model_output[:, :3]
return model_output
elif self.config.prediction_type == "sample":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
@@ -462,9 +421,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
timestep: int,
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the first-order DPMSolver (equivalent to DDIM).
@@ -483,31 +442,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
@@ -518,9 +455,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
*args,
sample: torch.FloatTensor = None,
**kwargs,
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
@@ -540,42 +477,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1]
sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1]
h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m1, (1.0 / r0) * (m0 - m1)
@@ -612,9 +518,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
*args,
sample: torch.FloatTensor = None,
**kwargs,
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
@@ -634,47 +540,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing`sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
self.lambda_t[t],
self.lambda_t[s0],
self.lambda_t[s1],
self.lambda_t[s2],
)
alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2]
sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2]
h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m2
@@ -716,10 +591,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_update(
self,
model_output_list: List[torch.FloatTensor],
*args,
sample: torch.FloatTensor = None,
order: int = None,
**kwargs,
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
order: int,
) -> torch.FloatTensor:
"""
One step for the singlestep DPMSolver.
@@ -740,60 +615,19 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing`sample` as a required keyward argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
raise ValueError(" missing `order` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if order == 1:
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample)
elif order == 2:
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
return self.singlestep_dpm_solver_second_order_update(
model_output_list, timestep_list, prev_timestep, sample
)
elif order == 3:
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
return self.singlestep_dpm_solver_third_order_update(
model_output_list, timestep_list, prev_timestep, sample
)
else:
raise ValueError(f"Order must be 1, 2, 3, got {order}")
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
def step(
self,
model_output: torch.FloatTensor,
@@ -826,15 +660,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
model_output = self.convert_model_output(model_output, sample=sample)
model_output = self.convert_model_output(model_output, timestep, sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
order = self.order_list[self.step_index]
order = self.order_list[step_index]
# For img2img denoising might start with order>1 which is not possible
# In this case make sure that the first two steps are both order=1
@@ -845,10 +685,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
if order == 1:
self.sample = sample
prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
# upon completion increase step index by one
self._step_index += 1
timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep]
prev_sample = self.singlestep_dpm_solver_update(
self.model_outputs, timestep_list, prev_timestep, self.sample, order
)
if not return_dict:
return (prev_sample,)
@@ -870,31 +710,28 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
@@ -1,265 +0,0 @@
# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
@flax.struct.dataclass
class EulerDiscreteSchedulerState:
common: CommonSchedulerState
# setable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
sigmas: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
def create(
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
@dataclass
class FlaxEulerDiscreteSchedulerOutput(FlaxSchedulerOutput):
state: EulerDiscreteSchedulerState
class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original
k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
dtype: jnp.dtype
@property
def has_state(self):
return True
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
dtype: jnp.dtype = jnp.float32,
):
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5
sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
init_noise_sigma = sigmas.max()
else:
init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5
return EulerDiscreteSchedulerState.create(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)
def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
"""
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args:
state (`EulerDiscreteSchedulerState`):
the `FlaxEulerDiscreteScheduler` state data class instance.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
timestep (`int`):
current discrete timestep in the diffusion chain.
Returns:
`jnp.ndarray`: scaled input sample
"""
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]
sigma = state.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def set_timesteps(
self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> EulerDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`EulerDiscreteSchedulerState`):
the `FlaxEulerDiscreteScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
if self.config.timestep_spacing == "linspace":
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += 1
else:
raise ValueError(
f"timestep_spacing must be one of ['linspace', 'leading'], got {self.config.timestep_spacing}"
)
sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
init_noise_sigma = sigmas.max()
else:
init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5
return state.replace(
timesteps=timesteps,
sigmas=sigmas,
num_inference_steps=num_inference_steps,
init_noise_sigma=init_noise_sigma,
)
def step(
self,
state: EulerDiscreteSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`EulerDiscreteSchedulerState`):
the `FlaxEulerDiscreteScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than FlaxEulerDiscreteScheduler class
Returns:
[`FlaxEulerDiscreteScheduler`] or `tuple`: [`FlaxEulerDiscreteScheduler`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]
sigma = state.sigmas[step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
# dt = sigma_down - sigma
dt = state.sigmas[step_index + 1] - sigma
prev_sample = sample + derivative * dt
if not return_dict:
return (prev_sample, state)
return FlaxEulerDiscreteSchedulerOutput(prev_sample=prev_sample, state=state)
def add_noise(
self,
state: EulerDiscreteSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sigma = state.sigmas[timesteps].flatten()
sigma = broadcast_to_shape_from_left(sigma, noise.shape)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
@@ -89,9 +89,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -116,7 +113,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
@@ -247,15 +243,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -279,13 +269,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
timesteps = torch.from_numpy(timesteps).to(device)
sigmas_interpol = sigmas_interpol.cpu()
log_sigmas = self.log_sigmas.cpu()
timesteps_interpol = np.array(
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
)
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
@@ -298,44 +282,29 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
def sigma_to_t(self, sigma):
# get log sigma
log_sigma = np.log(sigma)
log_sigma = sigma.log()
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
dists = log_sigma - self.log_sigmas[:, None]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
low = self.log_sigmas[low_idx]
high = self.log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
w = w.clamp(0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
t = t.view(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
@property
def state_in_first_order(self):
return self.sample is None
@@ -88,9 +88,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -115,7 +112,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
@@ -247,14 +243,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -269,12 +260,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps = torch.from_numpy(timesteps).to(device)
# interpolate timesteps
sigmas_interpol = sigmas_interpol.cpu()
log_sigmas = self.log_sigmas.cpu()
timesteps_interpol = np.array(
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
)
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
@@ -287,6 +273,29 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None
def sigma_to_t(self, sigma):
# get log sigma
log_sigma = sigma.log()
# get distribution
dists = log_sigma - self.log_sigmas[:, None]
# get sigmas range
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = self.log_sigmas[low_idx]
high = self.log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.view(sigma.shape)
return t
@property
def state_in_first_order(self):
return self.sample is None
@@ -309,44 +318,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index = step_index.item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
@@ -22,16 +22,10 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -44,30 +38,19 @@ def betas_for_alpha_bar(
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
@@ -198,14 +181,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.disable_corrector = disable_corrector
self.solver_p = solver_p
self.last_sample = None
self._step_index = None
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
@@ -245,17 +220,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
timesteps = np.flip(timesteps).copy().astype(np.int64)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
self.timesteps = torch.from_numpy(timesteps).to(device)
self.num_inference_steps = len(timesteps)
@@ -267,9 +243,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.solver_p:
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
@@ -329,13 +302,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
@@ -351,11 +317,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas
def convert_model_output(
self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
r"""
Convert the model output to the corresponding type the UniPC algorithm needs.
@@ -372,28 +334,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.predict_x0:
if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
@@ -409,9 +357,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.config.prediction_type == "epsilon":
return model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
@@ -423,10 +373,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_uni_p_bh_update(
self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
order: int = None,
**kwargs,
prev_timestep: int,
sample: torch.FloatTensor,
order: int,
) -> torch.FloatTensor:
"""
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
@@ -445,26 +394,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if order is None:
if len(args) > 2:
order = args[2]
else:
raise ValueError(" missing `order` as a required keyward argument")
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
timestep_list = self.timestep_list
model_output_list = self.model_outputs
s0 = self.timestep_list[-1]
s0, t = self.timestep_list[-1], prev_timestep
m0 = model_output_list[-1]
x = sample
@@ -472,12 +405,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h = lambda_t - lambda_s0
device = sample.device
@@ -485,10 +415,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - i
si = timestep_list[-(i + 1)]
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
lambda_si = self.lambda_t[si]
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
@@ -552,11 +481,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_uni_c_bh_update(
self,
this_model_output: torch.FloatTensor,
*args,
last_sample: torch.FloatTensor = None,
this_sample: torch.FloatTensor = None,
order: int = None,
**kwargs,
this_timestep: int,
last_sample: torch.FloatTensor,
this_sample: torch.FloatTensor,
order: int,
) -> torch.FloatTensor:
"""
One step for the UniC (B(h) version).
@@ -577,42 +505,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
The corrected sample tensor at the current timestep.
"""
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
if last_sample is None:
if len(args) > 1:
last_sample = args[1]
else:
raise ValueError(" missing`last_sample` as a required keyward argument")
if this_sample is None:
if len(args) > 2:
this_sample = args[2]
else:
raise ValueError(" missing`this_sample` as a required keyward argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
raise ValueError(" missing`order` as a required keyward argument")
if this_timestep is not None:
deprecate(
"this_timestep",
"1.0.0",
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
timestep_list = self.timestep_list
model_output_list = self.model_outputs
s0, t = timestep_list[-1], this_timestep
m0 = model_output_list[-1]
x = last_sample
x_t = this_sample
model_t = this_model_output
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h = lambda_t - lambda_s0
device = this_sample.device
@@ -620,10 +524,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - (i + 1)
si = timestep_list[-(i + 1)]
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
lambda_si = self.lambda_t[si]
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
@@ -686,25 +589,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
x_t = x_t.to(x.dtype)
return x_t
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
def step(
self,
model_output: torch.FloatTensor,
@@ -732,27 +616,37 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
use_corrector = (
self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None
)
model_output_convert = self.convert_model_output(model_output, sample=sample)
model_output_convert = self.convert_model_output(model_output, timestep, sample)
if use_corrector:
sample = self.multistep_uni_c_bh_update(
this_model_output=model_output_convert,
this_timestep=timestep,
last_sample=self.last_sample,
this_sample=sample,
order=self.this_order,
)
# now prepare to run the predictor
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1]
@@ -761,7 +655,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.timestep_list[-1] = timestep
if self.config.lower_order_final:
this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
else:
this_order = self.config.solver_order
@@ -771,6 +665,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.last_sample = sample
prev_sample = self.multistep_uni_p_bh_update(
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
prev_timestep=prev_timestep,
sample=sample,
order=self.this_order,
)
@@ -778,9 +673,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
@@ -801,31 +693,28 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
@@ -37,7 +37,6 @@ class FlaxKarrasDiffusionSchedulers(Enum):
FlaxPNDMScheduler = 3
FlaxLMSDiscreteScheduler = 4
FlaxDPMSolverMultistepScheduler = 5
FlaxEulerDiscreteScheduler = 6
@dataclass
-3
View File
@@ -67,7 +67,6 @@ from .import_utils import (
is_note_seq_available,
is_omegaconf_available,
is_onnx_available,
is_peft_available,
is_scipy_available,
is_tensorboard_available,
is_torch_available,
@@ -83,9 +82,7 @@ from .import_utils import (
from .loading_utils import load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import recurse_remove_peft_layers
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
logger = get_logger(__name__)
@@ -60,18 +60,3 @@ class FlaxStableDiffusionPipeline(metaclass=DummyObject):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])
class FlaxStableDiffusionXLPipeline(metaclass=DummyObject):
_backends = ["flax", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])
-15
View File
@@ -122,21 +122,6 @@ class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject):
requires_backends(cls, ["flax"])
class FlaxEulerDiscreteScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxKarrasVeScheduler(metaclass=DummyObject):
_backends = ["flax"]
-30
View File
@@ -315,36 +315,6 @@ class AutoPipelineForText2Image(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BlipDiffusionControlNetPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class BlipDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CLIPImageProjection(metaclass=DummyObject):
_backends = ["torch"]
-12
View File
@@ -267,14 +267,6 @@ except importlib_metadata.PackageNotFoundError:
_invisible_watermark_available = False
_peft_available = importlib.util.find_spec("peft") is not None
try:
_peft_version = importlib_metadata.version("peft")
logger.debug(f"Successfully imported peft version {_peft_version}")
except importlib_metadata.PackageNotFoundError:
_peft_available = False
def is_torch_available():
return _torch_available
@@ -359,10 +351,6 @@ def is_invisible_watermark_available():
return _invisible_watermark_available
def is_peft_available():
return _peft_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
-71
View File
@@ -1,71 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PEFT utilities: Utilities related to peft library
"""
from .import_utils import is_torch_available
if is_torch_available():
import torch
def recurse_remove_peft_layers(model):
r"""
Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
"""
from peft.tuners.lora import LoraLayer
for name, module in model.named_children():
if len(list(module.children())) > 0:
## compound module, go inside it
recurse_remove_peft_layers(module)
module_replaced = False
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
module.weight.device
)
new_module.weight = module.weight
if module.bias is not None:
new_module.bias = module.bias
module_replaced = True
elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
new_module = torch.nn.Conv2d(
module.in_channels,
module.out_channels,
module.kernel_size,
module.stride,
module.padding,
module.dilation,
module.groups,
module.bias,
).to(module.weight.device)
new_module.weight = module.weight
if module.bias is not None:
new_module.bias = module.bias
module_replaced = True
if module_replaced:
setattr(model, name, new_module)
del module
if torch.cuda.is_available():
torch.cuda.empty_cache()
return model
-184
View File
@@ -1,184 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
State dict utilities: utility methods for converting state dicts easily
"""
import enum
class StateDictType(enum.Enum):
"""
The mode to use when converting state dicts.
"""
DIFFUSERS_OLD = "diffusers_old"
# KOHYA_SS = "kohya_ss" # TODO: implement this
PEFT = "peft"
DIFFUSERS = "diffusers"
DIFFUSERS_TO_PEFT = {
".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
".k_proj.lora_linear_layer.up": ".k_proj.lora_B",
".k_proj.lora_linear_layer.down": ".k_proj.lora_A",
".v_proj.lora_linear_layer.up": ".v_proj.lora_B",
".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
}
DIFFUSERS_OLD_TO_PEFT = {
".to_q_lora.up": ".q_proj.lora_B",
".to_q_lora.down": ".q_proj.lora_A",
".to_k_lora.up": ".k_proj.lora_B",
".to_k_lora.down": ".k_proj.lora_A",
".to_v_lora.up": ".v_proj.lora_B",
".to_v_lora.down": ".v_proj.lora_A",
".to_out_lora.up": ".out_proj.lora_B",
".to_out_lora.down": ".out_proj.lora_A",
}
PEFT_TO_DIFFUSERS = {
".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
}
DIFFUSERS_OLD_TO_DIFFUSERS = {
".to_q_lora.up": ".q_proj.lora_linear_layer.up",
".to_q_lora.down": ".q_proj.lora_linear_layer.down",
".to_k_lora.up": ".k_proj.lora_linear_layer.up",
".to_k_lora.down": ".k_proj.lora_linear_layer.down",
".to_v_lora.up": ".v_proj.lora_linear_layer.up",
".to_v_lora.down": ".v_proj.lora_linear_layer.down",
".to_out_lora.up": ".out_proj.lora_linear_layer.up",
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
}
PEFT_STATE_DICT_MAPPINGS = {
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
}
DIFFUSERS_STATE_DICT_MAPPINGS = {
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}
def convert_state_dict(state_dict, mapping):
r"""
Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
mapping (`dict[str, str]`):
The mapping to use for conversion, the mapping should be a dictionary with the following structure:
- key: the pattern to replace
- value: the pattern to replace with
Returns:
converted_state_dict (`dict`)
The converted state dict.
"""
converted_state_dict = {}
for k, v in state_dict.items():
for pattern in mapping.keys():
if pattern in k:
new_pattern = mapping[pattern]
k = k.replace(pattern, new_pattern)
break
converted_state_dict[k] = v
return converted_state_dict
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs):
r"""
Converts a state dict to the PEFT format The state dict can be from previous diffusers format (`OLD_DIFFUSERS`), or
new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for now.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
original_type (`StateDictType`, *optional*):
The original type of the state dict, if not provided, the method will try to infer it automatically.
"""
if original_type is None:
# Old diffusers to PEFT
if any("to_out_lora" in k for k in state_dict.keys()):
original_type = StateDictType.DIFFUSERS_OLD
elif any("lora_linear_layer" in k for k in state_dict.keys()):
original_type = StateDictType.DIFFUSERS
else:
raise ValueError("Could not automatically infer state dict type")
if original_type not in PEFT_STATE_DICT_MAPPINGS.keys():
raise ValueError(f"Original type {original_type} is not supported")
mapping = PEFT_STATE_DICT_MAPPINGS[original_type]
return convert_state_dict(state_dict, mapping)
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
r"""
Converts a state dict to new diffusers format. The state dict can be from previous diffusers format
(`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will
return the state dict as is.
The method only supports the conversion from diffusers old, PEFT to diffusers new for now.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
original_type (`StateDictType`, *optional*):
The original type of the state dict, if not provided, the method will try to infer it automatically.
kwargs (`dict`, *args*):
Additional arguments to pass to the method.
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
`get_peft_model_state_dict` method:
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
but we add it here in case we don't want to rely on that method.
"""
peft_adapter_name = kwargs.pop("adapter_name", None)
if peft_adapter_name is not None:
peft_adapter_name = "." + peft_adapter_name
else:
peft_adapter_name = ""
if original_type is None:
# Old diffusers to PEFT
if any("to_out_lora" in k for k in state_dict.keys()):
original_type = StateDictType.DIFFUSERS_OLD
elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
original_type = StateDictType.PEFT
elif any("lora_linear_layer" in k for k in state_dict.keys()):
# nothing to do
return state_dict
else:
raise ValueError("Could not automatically infer state dict type")
if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
raise ValueError(f"Original type {original_type} is not supported")
mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
return convert_state_dict(state_dict, mapping)
-27
View File
@@ -1,4 +1,3 @@
import importlib
import inspect
import io
import logging
@@ -30,11 +29,9 @@ from .import_utils import (
is_note_seq_available,
is_onnx_available,
is_opencv_available,
is_peft_available,
is_torch_available,
is_torch_version,
is_torchsde_available,
is_transformers_available,
)
from .logging import get_logger
@@ -43,15 +40,6 @@ global_rng = random.Random()
logger = get_logger(__name__)
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")
_required_transformers_version = is_transformers_available() and version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33")
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
if is_torch_available():
import torch
@@ -248,21 +236,6 @@ def require_torchsde(test_case):
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
def require_peft_backend(test_case):
"""
Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and
transformers.
"""
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
"""
return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
if isinstance(arry, str):
# local_path = "/home/patrick_huggingface_co/"
@@ -52,15 +52,7 @@ from diffusers.models.attention_processor import (
XFormersAttnProcessor,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
deprecate_after_peft_backend,
floats_tensor,
load_image,
nightly,
require_torch_gpu,
slow,
torch_device,
)
from diffusers.utils.testing_utils import floats_tensor, load_image, nightly, require_torch_gpu, slow, torch_device
def create_lora_layers(model, mock_weights: bool = True):
@@ -189,7 +181,6 @@ def state_dicts_almost_equal(sd1, sd2):
return models_are_equal
@deprecate_after_peft_backend
class LoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
@@ -782,7 +773,6 @@ class SDXInpaintLoraMixinTests(unittest.TestCase):
assert np.abs(image_slice - image_slice_2).max() > 1e-2
@deprecate_after_peft_backend
class SDXLLoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
@@ -1886,25 +1876,6 @@ class LoraIntegrationTests(unittest.TestCase):
self.assertTrue(np.allclose(images, expected, atol=1e-3))
def test_lycoris(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16"
).to(torch_device)
lora_model_id = "hf-internal-testing/edgLycorisMugler-light"
lora_filename = "edgLycorisMugler-light.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
def test_a1111_with_model_cpu_offload(self):
generator = torch.Generator().manual_seed(0)

Some files were not shown because too many files have changed in this diff Show More