Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| abf4a9271e | |||
| 0e1fb0d916 | |||
| f77b7a0f27 | |||
| eae1371983 |
@@ -15,7 +15,6 @@ body:
|
||||
*The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*
|
||||
- 3. Add the **minimum amount of code / context that is needed to understand, reproduce your issue**.
|
||||
*Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*
|
||||
- 4. For issues related to community pipelines (i.e., the pipelines located in the `examples/community` folder), please tag the author of the pipeline in your issue thread as those pipelines are not maintained.
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
@@ -71,7 +70,7 @@ body:
|
||||
|
||||
Questions on schedulers: @patrickvonplaten and @williamberman
|
||||
|
||||
Questions on models and pipelines: @patrickvonplaten, @sayakpaul, and @williamberman (for community pipelines, please tag the original author of the pipeline)
|
||||
Questions on models and pipelines: @patrickvonplaten, @sayakpaul, and @williamberman
|
||||
|
||||
Questions on JAX- and MPS-related things: @pcuenca
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.7"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.7"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.7"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
name: Fast tests for PRs - PEFT backend
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
jobs:
|
||||
run_fast_tests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: LoRA
|
||||
framework: lora
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_lora
|
||||
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate.git
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers.git
|
||||
python -m pip install -U git+https://github.com/huggingface/peft.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
|
||||
if: ${{ matrix.config.framework == 'lora' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/lora/test_lora_layers_peft.py
|
||||
@@ -15,7 +15,6 @@ concurrency:
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 4
|
||||
HF_HOME: /mnt/cache
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: 3.8
|
||||
python-version: 3.7
|
||||
|
||||
- name: Install requirements
|
||||
run: |
|
||||
|
||||
@@ -216,8 +216,6 @@
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/blip_diffusion
|
||||
title: BLIP Diffusion
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# Blip Diffusion
|
||||
|
||||
Blip Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
|
||||
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.*
|
||||
|
||||
The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization.
|
||||
|
||||
`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## BlipDiffusionPipeline
|
||||
[[autodoc]] BlipDiffusionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## BlipDiffusionControlNetPipeline
|
||||
[[autodoc]] BlipDiffusionControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Install 🤗 Diffusers for whichever deep learning library you're working with.
|
||||
|
||||
🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
|
||||
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
|
||||
|
||||
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
|
||||
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
|
||||
@@ -106,7 +106,7 @@ pip install -e ".[flax]"
|
||||
|
||||
These commands will link the folder you cloned the repository to and your Python library paths.
|
||||
Python will now look inside the folder you cloned to in addition to the normal library paths.
|
||||
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
|
||||
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
|
||||
@@ -10,597 +10,91 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Image-to-image
|
||||
# Text-guided image-to-image generation
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
Image-to-image is similar to [text-to-image](conditional_image_generation), but in addition to a prompt, you can also pass an initial image as a starting point for the diffusion process. The initial image is encoded to latent space and noise is added to it. Then the latent diffusion model takes a prompt and the noisy latent image, predicts the added noise, and removes the predicted noise from the initial latent image to get the new latent image. Lastly, a decoder decodes the new latent image back into an image.
|
||||
The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images.
|
||||
|
||||
With 🤗 Diffusers, this is as easy as 1-2-3:
|
||||
|
||||
1. Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint:
|
||||
Before you begin, make sure you have all the necessary libraries installed:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
# uncomment to install the necessary libraries in Colab
|
||||
#!pip install diffusers transformers ftfy accelerate
|
||||
```
|
||||
|
||||
Get started by creating a [`StableDiffusionImg2ImgPipeline`] with a pretrained Stable Diffusion model like [`nitrosocke/Ghibli-Diffusion`](https://huggingface.co/nitrosocke/Ghibli-Diffusion).
|
||||
|
||||
```python
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
|
||||
device = "cuda"
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"nitrosocke/Ghibli-Diffusion", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to(device)
|
||||
```
|
||||
|
||||
Download and preprocess an initial image so you can pass it to the pipeline:
|
||||
|
||||
```python
|
||||
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
init_image.thumbnail((768, 768))
|
||||
init_image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/image_2_image_using_diffusers_cell_8_output_0.jpeg"/>
|
||||
</div>
|
||||
|
||||
<Tip>
|
||||
|
||||
You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](/optimization/torch2.0#scaled-dot-product-attention).
|
||||
💡 `strength` is a value between 0.0 and 1.0 that controls the amount of noise added to the input image. Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
|
||||
|
||||
</Tip>
|
||||
|
||||
2. Load an image to pass to the pipeline:
|
||||
Define the prompt (for this checkpoint finetuned on Ghibli-style art, you need to prefix the prompt with the `ghibli style` tokens) and run the pipeline:
|
||||
|
||||
```py
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
|
||||
```
|
||||
|
||||
3. Pass a prompt and image to the pipeline to generate an image:
|
||||
|
||||
```py
|
||||
prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
|
||||
image = pipeline(prompt, image=init_image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Popular models
|
||||
|
||||
The most popular image-to-image models are [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). The results from the Stable Diffusion and Kandinsky models vary due to their architecture differences and training process; you can generally expect SDXL to produce higher quality images than Stable Diffusion v1.5. Let's take a quick look at how to use each of these models and compare their results.
|
||||
|
||||
### Stable Diffusion v1.5
|
||||
|
||||
Stable Diffusion v1.5 is a latent diffusion model intialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image:
|
||||
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# prepare image
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
# pass prompt and image to pipeline
|
||||
image = pipeline(prompt, image=init_image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdv1.5.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Stable Diffusion XL (SDXL)
|
||||
|
||||
SDXL is a more powerful version of the Stable Diffusion model. It uses a larger base model, and an additional refiner model to increase the quality of the base model's output. Read the [SDXL](sdxl) guide for a more detailed walkthrough of how to use this model, and other techniques it uses to produce high quality images.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# prepare image
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
# pass prompt and image to pipeline
|
||||
image = pipeline(prompt, image=init_image, strength=).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Kandinsky 2.2
|
||||
|
||||
The Kandinsky model is different from the Stable Diffusion models because it uses an image prior model to create image embeddings. The embeddings help create a better alignment between text and images, allowing the latent diffusion model to generate better images.
|
||||
|
||||
The simplest way to use Kandinsky 2.2 is:
|
||||
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# prepare image
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
# pass prompt and image to pipeline
|
||||
image = pipeline(prompt, image=init_image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-kandinsky.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Configure pipeline parameters
|
||||
|
||||
There are several important parameters you can configure in the pipeline that'll affect the image generation process and image quality. Let's take a closer look at what these parameters do and how changing them affects the output.
|
||||
|
||||
### Strength
|
||||
|
||||
`strength` is one of the most important parameters to consider and it'll have a huge impact on your generated image. It determines how much the generated image resembles the initial image. In other words:
|
||||
|
||||
- 📈 a higher `strength` value gives the model more "creativity" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored
|
||||
- 📉 a lower `strength` value means the generated image is more similar to the initial image
|
||||
|
||||
The `strength` and `num_inference_steps` parameter are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# prepare image
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = init_image
|
||||
|
||||
# pass prompt and image to pipeline
|
||||
image = pipeline(prompt, image=init_image, strength=0.8).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-0.4.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 0.4</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-0.6.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 0.6</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-1.0.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 1.0</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Guidance scale
|
||||
|
||||
The `guidance_scale` parameter is used to control how closely aligned the generated image and text prompt are. A higher `guidance_scale` value means your generated image is more aligned with the prompt, while a lower `guidance_scale` value means your generated image has more space to deviate from the prompt.
|
||||
|
||||
You can combine `guidance_scale` with `strength` for even more precise control over how expressive the model is. For example, combine a high `strength + guidance_scale` for maximum creativity or use a combination of low `strength` and low `guidance_scale` to generate an image that resembles the initial image but is not as strictly bound to the prompt.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# prepare image
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
# pass prompt and image to pipeline
|
||||
image = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-0.1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 0.1</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-3.0.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 5.0</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-7.5.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 10.0</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Negative prompt
|
||||
|
||||
A negative prompt conditions the model to *not* include things in an image, and it can be used to improve image quality or modify an image. For example, you can improve image quality by including negative prompts like "poor details" or "blurry" to encourage the model to generate a higher quality image. Or you can modify an image by specifying things to exclude from an image.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# prepare image
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
|
||||
|
||||
# pass prompt and image to pipeline
|
||||
image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-negative-1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">negative prompt = "ugly, deformed, disfigured, poor details, bad anatomy"</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-negative-2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">negative prompt = "jungle"</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Chained image-to-image pipelines
|
||||
|
||||
There are some other interesting ways you can use an image-to-image pipeline aside from just generating an image (although that is pretty cool too). You can take it a step further and chain it with other pipelines.
|
||||
|
||||
### Text-to-image-to-image
|
||||
|
||||
Chaining a text-to-image and image-to-image pipeline allows you to generate an image from text and use the generated image as the initial image for the image-to-image pipeline. This is useful if you want to generate an image entirely from scratch. For example, let's chain a Stable Diffusion and a Kandinsky model.
|
||||
|
||||
Start by generating an image with the text-to-image pipeline:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
|
||||
```
|
||||
|
||||
Now you can pass this generated image to the image-to-image pipeline:
|
||||
|
||||
```py
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
### Image-to-image-to-image
|
||||
|
||||
You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generate short GIFs, restore color to an image, or restore missing areas of an image.
|
||||
|
||||
Start by generating an image:
|
||||
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# prepare image
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
# pass prompt and image to pipeline
|
||||
image = pipeline(prompt, image=init_image, output_type="latent").images[0]
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
|
||||
|
||||
</Tip>
|
||||
|
||||
Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion):
|
||||
|
||||
```py
|
||||
pipelne = AutoPipelineForImage2Image.from_pretrained(
|
||||
"ogkalu/Comic-Diffusion", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# need to include the token "charliebo artstyle" in the prompt to use this checkpoint
|
||||
image = pipeline("Astronaut in a jungle, charliebo artstyle", image=image, output_type="latent").images[0]
|
||||
```
|
||||
|
||||
Repeat one more time to generate the final image in a [pixel art style](https://huggingface.co/kohbanye/pixel-art-style):
|
||||
|
||||
```py
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"kohbanye/pixel-art-style", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# need to include the token "pixelartstyle" in the prompt to use this checkpoint
|
||||
image = pipeline("Astronaut in a jungle, pixelartstyle", image=image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
### Image-to-upscaler-to-super-resolution
|
||||
|
||||
Another way you can chain your image-to-image pipeline is with an upscaler and super-resolution pipeline to really increase the level of details in an image.
|
||||
|
||||
Start with an image-to-image pipeline:
|
||||
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# prepare image
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
# pass prompt and image to pipeline
|
||||
image_1 = pipeline(prompt, image=init_image, output_type="latent").images[0]
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
|
||||
|
||||
</Tip>
|
||||
|
||||
Chain it to an upscaler pipeline to increase the image resolution:
|
||||
|
||||
```py
|
||||
upscaler = AutoPipelineForImage2Image.from_pretrained(
|
||||
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
upscaler.enable_model_cpu_offload()
|
||||
upscaler.enable_xformers_memory_efficient_attention()
|
||||
|
||||
image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
|
||||
```
|
||||
|
||||
Finally, chain it to a super-resolution pipeline to further enhance the resolution:
|
||||
|
||||
```py
|
||||
super_res = AutoPipelineForImage2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
super_res.enable_model_cpu_offload()
|
||||
super_res.enable_xformers_memory_efficient_attention()
|
||||
|
||||
image_3 = upscaler(prompt, image=image_2).images[0]
|
||||
image_3
|
||||
```
|
||||
|
||||
## Control image generation
|
||||
|
||||
Trying to generate an image that looks exactly the way you want can be difficult, which is why controlled generation techniques and models are so useful. While you can use the `negative_prompt` to partially control image generation, there are more robust methods like prompt weighting and ControlNets.
|
||||
|
||||
### Prompt weighting
|
||||
|
||||
Prompt weighting allows you to scale the representation of each concept in a prompt. For example, in a prompt like "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", you can choose to increase or decrease the embeddings of "astronaut" and "jungle". The [Compel](https://github.com/damian0815/compel) library provides a simple syntax for adjusting prompt weights and generating the embeddings. You can learn how to create the embeddings in the [Prompt weighting](weighted_prompts) guide.
|
||||
|
||||
[`AutoPipelineForImage2Image`] has a `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter where you can pass the embeddings which replaces the `prompt` parameter.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
image = pipeline(prompt_emebds=prompt_embeds, # generated from Compel
|
||||
negative_prompt_embeds, # generated from Compel
|
||||
image=init_image,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### ControlNet
|
||||
|
||||
ControlNets provide a more flexible and accurate way to control image generation because you can use an additional conditioning image. The conditioning image can be a canny image, depth map, image segmentation, and even scribbles! Whatever type of conditioning image you choose, the ControlNet generates an image that preserves the information in it.
|
||||
|
||||
For example, let's condition an image with a depth map to keep the spatial information in the image.
|
||||
|
||||
```py
|
||||
# prepare image
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
init_image = init_image.resize((958, 960)) # resize to depth image dimensions
|
||||
depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png")
|
||||
```
|
||||
|
||||
Load a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]:
|
||||
|
||||
```py
|
||||
from diffusers import ControlNetModel, AutoPipelineForImage2Image
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
```
|
||||
|
||||
Now generate a new image conditioned on the depth map, initial image, and prompt:
|
||||
|
||||
```py
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipeline(prompt, image=init_image, control_image=depth_image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">depth image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-controlnet.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">ControlNet image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Let's apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion) to the image generated from the ControlNet by chaining it with an image-to-image pipeline:
|
||||
|
||||
```py
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
prompt = "elden ring style astronaut in a jungle" # include the token "elden ring style" in the prompt
|
||||
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
|
||||
|
||||
image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image, strength=0.45, guidance_scale=10.5).images[0]
|
||||
```python
|
||||
prompt = "ghibli style, a fantasy landscape with castles"
|
||||
generator = torch.Generator(device=device).manual_seed(1024)
|
||||
image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-elden-ring.png">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ghibli-castles.png"/>
|
||||
</div>
|
||||
|
||||
## Optimize
|
||||
You can also try experimenting with a different scheduler to see how that affects the output:
|
||||
|
||||
Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](optimization/torch2.0#scaled-dot-product-attention) or [xFormers](optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU.
|
||||
```python
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
|
||||
```diff
|
||||
+ pipeline.enable_model_cpu_offload()
|
||||
+ pipeline.enable_xformers_memory_efficient_attention()
|
||||
lms = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.scheduler = lms
|
||||
generator = torch.Generator(device=device).manual_seed(1024)
|
||||
image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
With [`torch.compile`](optimization/torch2.0#torch.compile), you can boost your inference speed even more by wrapping your UNet with it:
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lms-ghibli.png"/>
|
||||
</div>
|
||||
|
||||
```py
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
Check out the Spaces below, and try generating images with different values for `strength`. You'll notice that using lower values for `strength` produces images that are more similar to the original image.
|
||||
|
||||
To learn more, take a look at the [Reduce memory usage](optimization/memory) and [Torch 2.0](optimization/torch2.0) guides.
|
||||
Feel free to also switch the scheduler to the [`LMSDiscreteScheduler`] and see how that affects the output.
|
||||
|
||||
<iframe
|
||||
src="https://stevhliu-ghibli-img2img.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="500"
|
||||
></iframe>
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# What is safetensors ?
|
||||
|
||||
[safetensors](https://github.com/huggingface/safetensors) is a different format
|
||||
from the classic `.bin` which uses Pytorch which uses pickle.
|
||||
|
||||
Pickle is notoriously unsafe which allow any malicious file to execute arbitrary code.
|
||||
The hub itself tries to prevent issues from it, but it's not a silver bullet.
|
||||
|
||||
`safetensors` first and foremost goal is to make loading machine learning models *safe*
|
||||
in the sense that no takeover of your computer can be done.
|
||||
|
||||
# Why use safetensors ?
|
||||
|
||||
**Safety** can be one reason, if you're attempting to use a not well known model and
|
||||
you're not sure about the source of the file.
|
||||
|
||||
And a secondary reason, is **the speed of loading**. Safetensors can load models much faster
|
||||
than regular pickle files. If you spend a lot of times switching models, this can be
|
||||
a huge timesave.
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
사용하시는 라이브러리에 맞는 🤗 Diffusers를 설치하세요.
|
||||
|
||||
🤗 Diffusers는 Python 3.8+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요.
|
||||
🤗 Diffusers는 Python 3.7+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요.
|
||||
|
||||
- [PyTorch 설치 안내](https://pytorch.org/get-started/locally/)
|
||||
- [Flax 설치 안내](https://flax.readthedocs.io/en/latest/)
|
||||
@@ -105,7 +105,7 @@ pip install -e ".[flax]"
|
||||
|
||||
이러한 명령어들은 저장소를 복제한 폴더와 Python 라이브러리 경로를 연결합니다.
|
||||
Python은 이제 일반 라이브러리 경로에 더하여 복제한 폴더 내부를 살펴봅니다.
|
||||
예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.8/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
|
||||
예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.7/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
在你正在使用的任意深度学习框架中安装 🤗 Diffusers 。
|
||||
|
||||
🤗 Diffusers已在Python 3.8+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装:
|
||||
🤗 Diffusers已在Python 3.7+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装:
|
||||
|
||||
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
|
||||
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
|
||||
@@ -107,7 +107,7 @@ pip install -e ".[flax]"
|
||||
|
||||
这些命令将连接到你克隆的版本库和你的 Python 库路径。
|
||||
现在,不只是在通常的库路径,Python 还会在你克隆的文件夹内寻找包。
|
||||
例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.8/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`。
|
||||
例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.7/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`。
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
|
||||
@@ -908,9 +908,6 @@ def main():
|
||||
if args.snr_gamma is not None:
|
||||
snr = jnp.array(compute_snr(timesteps))
|
||||
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
snr_loss_weights = snr_loss_weights + 1
|
||||
loss = loss * snr_loss_weights
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -224,30 +224,6 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
|
||||
raise ValueError(f"{model_class} is not supported.")
|
||||
|
||||
|
||||
def compute_snr(timesteps, noise_scheduler):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
@@ -548,13 +524,6 @@ def parse_args(input_args=None):
|
||||
" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--snr_gamma",
|
||||
type=float,
|
||||
default=None,
|
||||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
||||
"More details here: https://arxiv.org/abs/2303.09556.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pre_compute_text_embeddings",
|
||||
action="store_true",
|
||||
@@ -1292,34 +1261,17 @@ def main(args):
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
|
||||
# Compute instance loss
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
|
||||
# Compute instance loss
|
||||
if args.snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps, noise_scheduler)
|
||||
base_weight = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
if args.with_prior_preservation:
|
||||
# Add the prior loss to the instance loss.
|
||||
loss = loss + args.prior_loss_weight * prior_loss
|
||||
else:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -875,9 +875,6 @@ def main():
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
mse_loss_weights = mse_loss_weights + 1
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
|
||||
@@ -955,9 +955,6 @@ def main():
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
mse_loss_weights = mse_loss_weights + 1
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
|
||||
@@ -786,9 +786,6 @@ def main():
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
mse_loss_weights = mse_loss_weights + 1
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
|
||||
@@ -1075,9 +1075,6 @@ def main(args):
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
mse_loss_weights = mse_loss_weights + 1
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
|
||||
@@ -332,6 +332,15 @@ def parse_args(input_args=None):
|
||||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
||||
"More details here: https://arxiv.org/abs/2303.09556.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_snr_gamma",
|
||||
action="store_true",
|
||||
help=(
|
||||
"When using SNR gamma with rescaled betas for zero terminal SNR, a divide-by-zero error can cause NaN"
|
||||
" condition when computing the SNR with a sigma value of zero. This parameter overrides the check,"
|
||||
" allowing the use of SNR gamma with a terminal SNR model. Use with caution, and closely monitor results."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
@@ -545,6 +554,18 @@ def main(args):
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
# Check for terminal SNR in combination with SNR Gamma
|
||||
if (
|
||||
args.snr_gamma
|
||||
and not args.force_snr_gamma
|
||||
and (
|
||||
hasattr(noise_scheduler.config, "rescale_betas_zero_snr") and noise_scheduler.config.rescale_betas_zero_snr
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f"The selected noise scheduler for the model {args.pretrained_model_name_or_path} uses rescaled betas for zero SNR.\n"
|
||||
"When this configuration is present, the parameter --snr_gamma may not be used without parameter --force_snr_gamma.\n"
|
||||
"This is due to a mathematical incompatibility between our current SNR gamma implementation, and a sigma value of zero."
|
||||
)
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
@@ -977,11 +998,6 @@ def main(args):
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
|
||||
elif noise_scheduler.config.prediction_type == "sample":
|
||||
# We set the target to latents here, but the model_pred will return the noise sample prediction.
|
||||
target = model_input
|
||||
# We will have to subtract the noise residual from the prediction to get the target sample.
|
||||
model_pred = model_pred - noise
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
@@ -992,17 +1008,9 @@ def main(args):
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
|
||||
@@ -1,343 +0,0 @@
|
||||
"""
|
||||
This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from transformers import CLIPTokenizer
|
||||
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines import BlipDiffusionPipeline
|
||||
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
|
||||
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
||||
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
||||
|
||||
|
||||
BLIP2_CONFIG = {
|
||||
"vision_config": {
|
||||
"hidden_size": 1024,
|
||||
"num_hidden_layers": 23,
|
||||
"num_attention_heads": 16,
|
||||
"image_size": 224,
|
||||
"patch_size": 14,
|
||||
"intermediate_size": 4096,
|
||||
"hidden_act": "quick_gelu",
|
||||
},
|
||||
"qformer_config": {
|
||||
"cross_attention_frequency": 1,
|
||||
"encoder_hidden_size": 1024,
|
||||
"vocab_size": 30523,
|
||||
},
|
||||
"num_query_tokens": 16,
|
||||
}
|
||||
blip2config = Blip2Config(**BLIP2_CONFIG)
|
||||
|
||||
|
||||
def qformer_model_from_original_config():
|
||||
qformer = Blip2QFormerModel(blip2config)
|
||||
return qformer
|
||||
|
||||
|
||||
def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
|
||||
embeddings = {}
|
||||
embeddings.update(
|
||||
{
|
||||
f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
|
||||
f"{original_embeddings_prefix}.word_embeddings.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
embeddings.update(
|
||||
{
|
||||
f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
|
||||
f"{original_embeddings_prefix}.position_embeddings.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
embeddings.update(
|
||||
{f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
|
||||
)
|
||||
embeddings.update(
|
||||
{f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
|
||||
proj_layer = {}
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
|
||||
return proj_layer
|
||||
|
||||
|
||||
def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
|
||||
attention = {}
|
||||
attention.update(
|
||||
{
|
||||
f"{diffuser_attention_prefix}.attention.query.weight": model[
|
||||
f"{original_attention_prefix}.self.query.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
|
||||
)
|
||||
attention.update(
|
||||
{
|
||||
f"{diffuser_attention_prefix}.attention.value.weight": model[
|
||||
f"{original_attention_prefix}.self.value.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
|
||||
)
|
||||
attention.update(
|
||||
{
|
||||
f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
|
||||
f"{original_attention_prefix}.output.LayerNorm.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
attention.update(
|
||||
{
|
||||
f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
|
||||
f"{original_attention_prefix}.output.LayerNorm.bias"
|
||||
]
|
||||
}
|
||||
)
|
||||
return attention
|
||||
|
||||
|
||||
def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
|
||||
output_layers = {}
|
||||
output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
|
||||
output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
|
||||
output_layers.update(
|
||||
{f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
|
||||
)
|
||||
output_layers.update(
|
||||
{f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
|
||||
)
|
||||
return output_layers
|
||||
|
||||
|
||||
def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
|
||||
encoder = {}
|
||||
for i in range(blip2config.qformer_config.num_hidden_layers):
|
||||
encoder.update(
|
||||
attention_from_original_checkpoint(
|
||||
model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
|
||||
)
|
||||
)
|
||||
encoder.update(
|
||||
attention_from_original_checkpoint(
|
||||
model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
|
||||
)
|
||||
)
|
||||
|
||||
encoder.update(
|
||||
{
|
||||
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
|
||||
f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
encoder.update(
|
||||
{
|
||||
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
|
||||
f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
|
||||
]
|
||||
}
|
||||
)
|
||||
encoder.update(
|
||||
{
|
||||
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
|
||||
f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
encoder.update(
|
||||
{
|
||||
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
|
||||
f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
encoder.update(
|
||||
output_layers_from_original_checkpoint(
|
||||
model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
|
||||
)
|
||||
)
|
||||
encoder.update(
|
||||
output_layers_from_original_checkpoint(
|
||||
model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
|
||||
)
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
||||
def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
|
||||
visual_encoder_layer = {}
|
||||
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
|
||||
visual_encoder_layer.update(
|
||||
{f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
|
||||
)
|
||||
visual_encoder_layer.update(
|
||||
{f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
|
||||
)
|
||||
visual_encoder_layer.update(
|
||||
{f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
|
||||
)
|
||||
visual_encoder_layer.update(
|
||||
{f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
|
||||
)
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
|
||||
|
||||
return visual_encoder_layer
|
||||
|
||||
|
||||
def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
|
||||
visual_encoder = {}
|
||||
|
||||
visual_encoder.update(
|
||||
{
|
||||
f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0)
|
||||
}
|
||||
)
|
||||
visual_encoder.update(
|
||||
{
|
||||
f"{diffuser_prefix}.embeddings.position_embedding": model[
|
||||
f"{original_prefix}.positional_embedding"
|
||||
].unsqueeze(0)
|
||||
}
|
||||
)
|
||||
visual_encoder.update(
|
||||
{f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
|
||||
)
|
||||
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
|
||||
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
|
||||
|
||||
for i in range(blip2config.vision_config.num_hidden_layers):
|
||||
visual_encoder.update(
|
||||
visual_encoder_layer_from_original_checkpoint(
|
||||
model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
|
||||
)
|
||||
)
|
||||
|
||||
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
|
||||
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
|
||||
|
||||
return visual_encoder
|
||||
|
||||
|
||||
def qformer_original_checkpoint_to_diffusers_checkpoint(model):
|
||||
qformer_checkpoint = {}
|
||||
qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
|
||||
qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
|
||||
qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
|
||||
qformer_checkpoint.update(
|
||||
encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
|
||||
)
|
||||
qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
|
||||
return qformer_checkpoint
|
||||
|
||||
|
||||
def get_qformer(model):
|
||||
print("loading qformer")
|
||||
|
||||
qformer = qformer_model_from_original_config()
|
||||
qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
|
||||
|
||||
load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
|
||||
|
||||
print("done loading qformer")
|
||||
return qformer
|
||||
|
||||
|
||||
def load_checkpoint_to_model(checkpoint, model):
|
||||
with tempfile.NamedTemporaryFile(delete=False) as file:
|
||||
torch.save(checkpoint, file.name)
|
||||
del checkpoint
|
||||
model.load_state_dict(torch.load(file.name), strict=False)
|
||||
|
||||
os.remove(file.name)
|
||||
|
||||
|
||||
def save_blip_diffusion_model(model, args):
|
||||
qformer = get_qformer(model)
|
||||
qformer.eval()
|
||||
|
||||
text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
||||
vae.eval()
|
||||
text_encoder.eval()
|
||||
scheduler = PNDMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
set_alpha_to_one=False,
|
||||
skip_prk_steps=True,
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
|
||||
image_processor = BlipImageProcessor()
|
||||
blip_diffusion = BlipDiffusionPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
qformer=qformer,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
blip_diffusion.save_pretrained(args.checkpoint_path)
|
||||
|
||||
|
||||
def main(args):
|
||||
model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
|
||||
save_blip_diffusion_model(model.state_dict(), args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@@ -35,12 +35,6 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
help="The YAML config file corresponding to the original architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_files",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The YAML config file corresponding to the architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_in_channels",
|
||||
default=None,
|
||||
|
||||
@@ -256,7 +256,7 @@ setup(
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
include_package_data=True,
|
||||
python_requires=">=3.8.0",
|
||||
python_requires=">=3.7.0",
|
||||
install_requires=list(install_requires),
|
||||
extras_require=extras,
|
||||
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
|
||||
@@ -268,6 +268,7 @@ setup(
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
|
||||
@@ -197,8 +197,6 @@ else:
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
"AudioLDMPipeline",
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"CLIPImageProjection",
|
||||
"CycleDiffusionPipeline",
|
||||
"IFImg2ImgPipeline",
|
||||
@@ -368,7 +366,6 @@ else:
|
||||
"FlaxDDIMScheduler",
|
||||
"FlaxDDPMScheduler",
|
||||
"FlaxDPMSolverMultistepScheduler",
|
||||
"FlaxEulerDiscreteScheduler",
|
||||
"FlaxKarrasVeScheduler",
|
||||
"FlaxLMSDiscreteScheduler",
|
||||
"FlaxPNDMScheduler",
|
||||
@@ -396,7 +393,6 @@ else:
|
||||
"FlaxStableDiffusionImg2ImgPipeline",
|
||||
"FlaxStableDiffusionInpaintPipeline",
|
||||
"FlaxStableDiffusionPipeline",
|
||||
"FlaxStableDiffusionXLPipeline",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -462,8 +458,6 @@ if TYPE_CHECKING:
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForInpainting,
|
||||
AutoPipelineForText2Image,
|
||||
BlipDiffusionControlNetPipeline,
|
||||
BlipDiffusionPipeline,
|
||||
CLIPImageProjection,
|
||||
ConsistencyModelPipeline,
|
||||
DanceDiffusionPipeline,
|
||||
@@ -675,7 +669,6 @@ if TYPE_CHECKING:
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDDPMScheduler,
|
||||
FlaxDPMSolverMultistepScheduler,
|
||||
FlaxEulerDiscreteScheduler,
|
||||
FlaxKarrasVeScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
@@ -694,7 +687,6 @@ if TYPE_CHECKING:
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
FlaxStableDiffusionXLPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
+118
-205
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
@@ -24,7 +23,6 @@ import requests
|
||||
import safetensors
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, model_info
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
@@ -32,15 +30,11 @@ from .utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
HF_HUB_OFFLINE,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_omegaconf_available,
|
||||
is_peft_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
)
|
||||
from .utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
@@ -67,21 +61,6 @@ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
||||
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||
# available.
|
||||
# For PEFT it is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
|
||||
_required_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
) > version.parse("0.5")
|
||||
_required_transformers_version = version.parse(
|
||||
version.parse(importlib.metadata.version("transformers")).base_version
|
||||
) > version.parse("4.33")
|
||||
|
||||
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
|
||||
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."
|
||||
|
||||
|
||||
class PatchedLoraProjection(nn.Module):
|
||||
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
|
||||
super().__init__()
|
||||
@@ -1098,7 +1077,6 @@ class LoraLoaderMixin:
|
||||
text_encoder_name = TEXT_ENCODER_NAME
|
||||
unet_name = UNET_NAME
|
||||
num_fused_loras = 0
|
||||
use_peft_backend = USE_PEFT_BACKEND
|
||||
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
"""
|
||||
@@ -1290,7 +1268,6 @@ class LoraLoaderMixin:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
network_alphas = None
|
||||
# TODO: replace it with a method from `state_dict_utils`
|
||||
if all(
|
||||
(
|
||||
k.startswith("lora_te_")
|
||||
@@ -1543,35 +1520,55 @@ class LoraLoaderMixin:
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
if cls.use_peft_backend:
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
|
||||
# Convert from the old naming convention to the new naming convention.
|
||||
#
|
||||
# Previously, the old LoRA layers were stored on the state dict at the
|
||||
# same level as the attention block i.e.
|
||||
# `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
|
||||
#
|
||||
# This is no actual module at that point, they were monkey patched on to the
|
||||
# existing module. We want to be able to load them via their actual state dict.
|
||||
# They're in `PatchedLoraProjection.lora_linear_layer` now.
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.q_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.k_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.v_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
else:
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.q_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.k_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.v_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.out_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
|
||||
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
@@ -1581,79 +1578,56 @@ class LoraLoaderMixin:
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
if cls.use_peft_backend:
|
||||
from peft import LoraConfig
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
lora_scale,
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
lora_rank = list(rank.values())[0]
|
||||
# By definition, the scale should be alpha divided by rank.
|
||||
# https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
|
||||
alpha = lora_scale * lora_rank
|
||||
is_pipeline_offloaded = _pipeline is not None and any(
|
||||
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
|
||||
)
|
||||
if is_pipeline_offloaded and low_cpu_mem_usage:
|
||||
low_cpu_mem_usage = True
|
||||
logger.info(
|
||||
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
|
||||
)
|
||||
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
if patch_mlp:
|
||||
target_modules += ["fc1", "fc2"]
|
||||
|
||||
# TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
|
||||
lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
|
||||
|
||||
text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
|
||||
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(text_encoder_lora_state_dict.values())).device
|
||||
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
lora_scale,
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
|
||||
unexpected_keys = load_state_dict_results.unexpected_keys
|
||||
|
||||
if len(unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
|
||||
)
|
||||
|
||||
is_pipeline_offloaded = _pipeline is not None and any(
|
||||
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
|
||||
for c in _pipeline.components.values()
|
||||
)
|
||||
if is_pipeline_offloaded and low_cpu_mem_usage:
|
||||
low_cpu_mem_usage = True
|
||||
logger.info(
|
||||
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(text_encoder_lora_state_dict.values())).device
|
||||
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
load_state_dict_results = text_encoder.load_state_dict(
|
||||
text_encoder_lora_state_dict, strict=False
|
||||
)
|
||||
unexpected_keys = load_state_dict_results.unexpected_keys
|
||||
|
||||
if len(unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
|
||||
)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = isinstance(
|
||||
getattr(component, "_hf_hook"), AlignDevicesHook
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = isinstance(
|
||||
getattr(component, "_hf_hook"), AlignDevicesHook
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
@@ -1671,27 +1645,10 @@ class LoraLoaderMixin:
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if self.use_peft_backend:
|
||||
remove_method = recurse_remove_peft_layers
|
||||
else:
|
||||
remove_method = self._remove_text_encoder_monkey_patch_classmethod
|
||||
|
||||
if hasattr(self, "text_encoder"):
|
||||
remove_method(self.text_encoder)
|
||||
|
||||
if self.use_peft_backend:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
remove_method(self.text_encoder_2)
|
||||
if self.use_peft_backend:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
|
||||
@classmethod
|
||||
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
|
||||
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_linear_layer = None
|
||||
@@ -1718,7 +1675,6 @@ class LoraLoaderMixin:
|
||||
r"""
|
||||
Monkey-patches the forward passes of attention modules of the text encoder.
|
||||
"""
|
||||
deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
|
||||
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
|
||||
@@ -1922,7 +1878,7 @@ class LoraLoaderMixin:
|
||||
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
|
||||
|
||||
# SDXL specificity.
|
||||
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
|
||||
if "emb" in diffusers_name:
|
||||
pattern = r"\.\d+(?=\D*$)"
|
||||
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
|
||||
if ".in." in diffusers_name:
|
||||
@@ -1934,13 +1890,6 @@ class LoraLoaderMixin:
|
||||
if "skip" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
|
||||
|
||||
# LyCORIS specificity.
|
||||
if "time.emb.proj" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
|
||||
if "conv.shortcut" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
|
||||
|
||||
# General coverage.
|
||||
if "transformer_blocks" in diffusers_name:
|
||||
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
||||
@@ -2093,38 +2042,24 @@ class LoraLoaderMixin:
|
||||
if fuse_unet:
|
||||
self.unet.fuse_lora(lora_scale)
|
||||
|
||||
if self.use_peft_backend:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
def fuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale)
|
||||
attn_module.k_proj._fuse_lora(lora_scale)
|
||||
attn_module.v_proj._fuse_lora(lora_scale)
|
||||
attn_module.out_proj._fuse_lora(lora_scale)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
module.merge()
|
||||
|
||||
else:
|
||||
deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale)
|
||||
attn_module.k_proj._fuse_lora(lora_scale)
|
||||
attn_module.v_proj._fuse_lora(lora_scale)
|
||||
attn_module.out_proj._fuse_lora(lora_scale)
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora(lora_scale)
|
||||
mlp_module.fc2._fuse_lora(lora_scale)
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora(lora_scale)
|
||||
mlp_module.fc2._fuse_lora(lora_scale)
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
fuse_text_encoder_lora(self.text_encoder, lora_scale)
|
||||
fuse_text_encoder_lora(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
|
||||
fuse_text_encoder_lora(self.text_encoder_2)
|
||||
|
||||
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
|
||||
r"""
|
||||
@@ -2146,29 +2081,18 @@ class LoraLoaderMixin:
|
||||
if unfuse_unet:
|
||||
self.unet.unfuse_lora()
|
||||
|
||||
if self.use_peft_backend:
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._unfuse_lora()
|
||||
attn_module.k_proj._unfuse_lora()
|
||||
attn_module.v_proj._unfuse_lora()
|
||||
attn_module.out_proj._unfuse_lora()
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
else:
|
||||
deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._unfuse_lora()
|
||||
attn_module.k_proj._unfuse_lora()
|
||||
attn_module.v_proj._unfuse_lora()
|
||||
attn_module.out_proj._unfuse_lora()
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._unfuse_lora()
|
||||
mlp_module.fc2._unfuse_lora()
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._unfuse_lora()
|
||||
mlp_module.fc2._unfuse_lora()
|
||||
|
||||
if unfuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -2879,16 +2803,5 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
)
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if self.use_peft_backend:
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
# TODO: @younesbelkada handle this in transformers side
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
else:
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
|
||||
@@ -19,7 +19,6 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from .activations import get_activation
|
||||
from .lora import LoRACompatibleLinear
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -167,7 +166,7 @@ class TimestepEmbedding(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
||||
@@ -180,7 +179,7 @@ class TimestepEmbedding(nn.Module):
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
|
||||
@@ -25,25 +25,18 @@ from ..utils import logging
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
|
||||
if use_peft_backend:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_scale = lora_scale
|
||||
attn_module.k_proj.lora_scale = lora_scale
|
||||
attn_module.v_proj.lora_scale = lora_scale
|
||||
attn_module.out_proj.lora_scale = lora_scale
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
module.scaling[module.active_adapter] = lora_scale
|
||||
else:
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_scale = lora_scale
|
||||
attn_module.k_proj.lora_scale = lora_scale
|
||||
attn_module.v_proj.lora_scale = lora_scale
|
||||
attn_module.out_proj.lora_scale = lora_scale
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_scale = lora_scale
|
||||
mlp_module.fc2.lora_scale = lora_scale
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_scale = lora_scale
|
||||
mlp_module.fc2.lora_scale = lora_scale
|
||||
|
||||
|
||||
class LoRALinearLayer(nn.Module):
|
||||
|
||||
@@ -42,25 +42,9 @@ def rename_key(key):
|
||||
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
|
||||
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
|
||||
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
|
||||
|
||||
# conv norm or layer norm
|
||||
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
||||
|
||||
# rename attention layers
|
||||
if len(pt_tuple_key) > 1:
|
||||
for rename_from, rename_to in (
|
||||
("to_out_0", "proj_attn"),
|
||||
("to_k", "key"),
|
||||
("to_v", "value"),
|
||||
("to_q", "query"),
|
||||
):
|
||||
if pt_tuple_key[-2] == rename_from:
|
||||
weight_name = pt_tuple_key[-1]
|
||||
weight_name = "kernel" if weight_name == "weight" else weight_name
|
||||
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
|
||||
if renamed_pt_tuple_key in random_flax_state_dict:
|
||||
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
|
||||
return renamed_pt_tuple_key, pt_tensor.T
|
||||
|
||||
if (
|
||||
any("norm" in str_ for str_ in pt_tuple_key)
|
||||
and (pt_tuple_key[-1] == "bias")
|
||||
|
||||
@@ -303,23 +303,23 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
"framework": "flax",
|
||||
}
|
||||
|
||||
# Load config if we don't provide one
|
||||
if config is None:
|
||||
config, unused_kwargs = cls.load_config(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = config if config is not None else pretrained_model_name_or_path
|
||||
model, model_kwargs = cls.from_config(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
# model args
|
||||
dtype=dtype,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Load model
|
||||
pretrained_path_with_subfolder = (
|
||||
|
||||
@@ -52,7 +52,6 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
only_cross_attention: bool = False
|
||||
use_memory_efficient_attention: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
transformer_layers_per_block: int = 1
|
||||
|
||||
def setup(self):
|
||||
resnets = []
|
||||
@@ -73,7 +72,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
in_channels=self.out_channels,
|
||||
n_heads=self.num_attention_heads,
|
||||
d_head=self.out_channels // self.num_attention_heads,
|
||||
depth=self.transformer_layers_per_block,
|
||||
depth=1,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=self.only_cross_attention,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
@@ -193,7 +192,6 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
only_cross_attention: bool = False
|
||||
use_memory_efficient_attention: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
transformer_layers_per_block: int = 1
|
||||
|
||||
def setup(self):
|
||||
resnets = []
|
||||
@@ -215,7 +213,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
in_channels=self.out_channels,
|
||||
n_heads=self.num_attention_heads,
|
||||
d_head=self.out_channels // self.num_attention_heads,
|
||||
depth=self.transformer_layers_per_block,
|
||||
depth=1,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=self.only_cross_attention,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
@@ -333,7 +331,6 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
use_linear_projection: bool = False
|
||||
use_memory_efficient_attention: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
transformer_layers_per_block: int = 1
|
||||
|
||||
def setup(self):
|
||||
# there is always at least one resnet
|
||||
@@ -353,7 +350,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
in_channels=self.in_channels,
|
||||
n_heads=self.num_attention_heads,
|
||||
d_head=self.in_channels // self.num_attention_heads,
|
||||
depth=self.transformer_layers_per_block,
|
||||
depth=1,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
dtype=self.dtype,
|
||||
|
||||
@@ -883,6 +883,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
@@ -116,11 +116,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
flip_sin_to_cos: bool = True
|
||||
freq_shift: int = 0
|
||||
use_memory_efficient_attention: bool = False
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1
|
||||
addition_embed_type: Optional[str] = None
|
||||
addition_time_embed_dim: Optional[int] = None
|
||||
addition_embed_type_num_heads: int = 64
|
||||
projection_class_embeddings_input_dim: Optional[int] = None
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
||||
# init input tensors
|
||||
@@ -132,17 +127,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
added_cond_kwargs = None
|
||||
if self.addition_embed_type == "text_time":
|
||||
# TODO: how to get this from the config? It's no longer cross_attention_dim
|
||||
text_embeds_dim = 1280
|
||||
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
|
||||
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
|
||||
"time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
|
||||
}
|
||||
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
|
||||
return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
|
||||
|
||||
def setup(self):
|
||||
block_out_channels = self.block_out_channels
|
||||
@@ -183,24 +168,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
|
||||
|
||||
# transformer layers per block
|
||||
transformer_layers_per_block = self.transformer_layers_per_block
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
|
||||
|
||||
# addition embed types
|
||||
if self.addition_embed_type is None:
|
||||
self.add_embedding = None
|
||||
elif self.addition_embed_type == "text_time":
|
||||
if self.addition_time_embed_dim is None:
|
||||
raise ValueError(
|
||||
f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
|
||||
)
|
||||
self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
|
||||
self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
||||
else:
|
||||
raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")
|
||||
|
||||
# down
|
||||
down_blocks = []
|
||||
output_channel = block_out_channels[0]
|
||||
@@ -215,7 +182,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
out_channels=output_channel,
|
||||
dropout=self.dropout,
|
||||
num_layers=self.layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
add_downsample=not is_final_block,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
@@ -241,7 +207,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=self.dropout,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
dtype=self.dtype,
|
||||
@@ -253,7 +218,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
||||
for i, up_block_type in enumerate(self.up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
@@ -267,7 +231,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
num_layers=self.layers_per_block + 1,
|
||||
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||
num_attention_heads=reversed_num_attention_heads[i],
|
||||
add_upsample=not is_final_block,
|
||||
dropout=self.dropout,
|
||||
@@ -306,7 +269,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
sample,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
|
||||
down_block_additional_residuals=None,
|
||||
mid_block_additional_residual=None,
|
||||
return_dict: bool = True,
|
||||
@@ -338,31 +300,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
t_emb = self.time_proj(timesteps)
|
||||
t_emb = self.time_embedding(t_emb)
|
||||
|
||||
# additional embeddings
|
||||
aug_emb = None
|
||||
if self.addition_embed_type == "text_time":
|
||||
if added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if text_embeds is None:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
if time_ids is None:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
# compute time embeds
|
||||
time_embeds = self.add_time_proj(jnp.ravel(time_ids)) # (1, 6) => (6,) => (6, 256)
|
||||
time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
|
||||
add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
t_emb = t_emb + aug_emb if aug_emb is not None else t_emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = jnp.transpose(sample, (0, 2, 3, 1))
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
@@ -16,7 +16,7 @@ from ..utils import (
|
||||
|
||||
# These modules contain pipelines from multiple libraries/frameworks
|
||||
_dummy_objects = {}
|
||||
_import_structure = {"stable_diffusion": [], "stable_diffusion_xl": [], "latent_diffusion": [], "controlnet": []}
|
||||
_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
@@ -67,10 +67,8 @@ else:
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
]
|
||||
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
@@ -142,14 +140,12 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
|
||||
_import_structure["stable_diffusion_xl"].extend(
|
||||
[
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_xl"] = [
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
]
|
||||
_import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
|
||||
_import_structure["text_to_video_synthesis"] = [
|
||||
"TextToVideoSDPipeline",
|
||||
@@ -202,7 +198,6 @@ else:
|
||||
"StableDiffusionOnnxPipeline",
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -237,11 +232,6 @@ else:
|
||||
"FlaxStableDiffusionPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_xl"].extend(
|
||||
[
|
||||
"FlaxStableDiffusionXLPipeline",
|
||||
]
|
||||
)
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -291,9 +281,7 @@ if TYPE_CHECKING:
|
||||
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .controlnet import (
|
||||
BlipDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
@@ -445,9 +433,6 @@ if TYPE_CHECKING:
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import (
|
||||
FlaxStableDiffusionXLPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
|
||||
@@ -303,7 +303,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -301,7 +301,7 @@ class AltDiffusionImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
|
||||
else:
|
||||
from .blip_image_processing import BlipImageProcessor
|
||||
from .modeling_blip2 import Blip2QFormerModel
|
||||
from .modeling_ctx_clip import ContextCLIPTextModel
|
||||
from .pipeline_blip_diffusion import BlipDiffusionPipeline
|
||||
@@ -1,318 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for BLIP."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
|
||||
from transformers.image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from transformers.utils import TensorType, is_vision_available, logging
|
||||
|
||||
from diffusers.utils import numpy_to_pil
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop
|
||||
# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor
|
||||
class BlipImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a BLIP image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
||||
`do_resize` parameter in the `preprocess` method.
|
||||
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
|
||||
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
||||
method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
|
||||
overridden by the `resample` parameter in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
|
||||
`do_rescale` parameter in the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
|
||||
overridden by the `rescale_factor` parameter in the `preprocess` method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
||||
method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
|
||||
overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
do_center_crop: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 224, "width": 224}
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.do_center_crop = do_center_crop
|
||||
|
||||
# Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image to `(size["height"], size["width"])`.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
do_center_crop: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Controls the size of the image after `resize`. The shortest edge of the image is resized to
|
||||
`size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
|
||||
is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
|
||||
edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image values between [0 - 1].
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to normalize the image by if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
images = make_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
if do_resize and size is None or resample is None:
|
||||
raise ValueError("Size and resample must be specified if do_resize is True.")
|
||||
|
||||
if do_rescale and rescale_factor is None:
|
||||
raise ValueError("Rescale factor must be specified if do_rescale is True.")
|
||||
|
||||
if do_normalize and (image_mean is None or image_std is None):
|
||||
raise ValueError("Image mean and std must be specified if do_normalize is True.")
|
||||
|
||||
# PIL RGBA images are converted to RGB
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
if do_normalize:
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
||||
return encoded_outputs
|
||||
|
||||
# Follows diffusers.VaeImageProcessor.postprocess
|
||||
def postprocess(self, sample: torch.FloatTensor, output_type: str = "pil"):
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(
|
||||
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
|
||||
)
|
||||
|
||||
# Equivalent to diffusers.VaeImageProcessor.denormalize
|
||||
sample = (sample / 2 + 0.5).clamp(0, 1)
|
||||
if output_type == "pt":
|
||||
return sample
|
||||
|
||||
# Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
|
||||
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
|
||||
if output_type == "np":
|
||||
return sample
|
||||
# Output_type must be 'pil'
|
||||
sample = numpy_to_pil(sample)
|
||||
return sample
|
||||
@@ -1,642 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from transformers import BertTokenizer
|
||||
from transformers.activations import QuickGELUActivation as QuickGELU
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
)
|
||||
from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig
|
||||
from transformers.models.blip_2.modeling_blip_2 import (
|
||||
Blip2Encoder,
|
||||
Blip2PreTrainedModel,
|
||||
Blip2QFormerAttention,
|
||||
Blip2QFormerIntermediate,
|
||||
Blip2QFormerOutput,
|
||||
)
|
||||
from transformers.pytorch_utils import apply_chunking_to_forward
|
||||
from transformers.utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py.
|
||||
# But it doesn't support getting multimodal embeddings. So, this module can be
|
||||
# replaced with a future `transformers` version supports that.
|
||||
class Blip2TextEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word and position embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
query_embeds=None,
|
||||
past_key_values_length=0,
|
||||
):
|
||||
if input_ids is not None:
|
||||
seq_length = input_ids.size()[1]
|
||||
else:
|
||||
seq_length = 0
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
|
||||
|
||||
if input_ids is not None:
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.position_embedding_type == "absolute":
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
embeddings = embeddings + position_embeddings
|
||||
|
||||
if query_embeds is not None:
|
||||
batch_size = embeddings.shape[0]
|
||||
# repeat the query embeddings for batch size
|
||||
query_embeds = query_embeds.repeat(batch_size, 1, 1)
|
||||
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
||||
else:
|
||||
embeddings = query_embeds
|
||||
embeddings = embeddings.to(query_embeds.dtype)
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
|
||||
class Blip2VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: Blip2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
|
||||
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
|
||||
return embeddings
|
||||
|
||||
|
||||
# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings
|
||||
class Blip2QFormerEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList(
|
||||
[Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
query_length=0,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
layer_module = self.layer[i]
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, past_key_value, output_attentions, query_length)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer_module),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
query_length,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if layer_module.has_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
# The layers making up the Qformer encoder
|
||||
class Blip2QFormerLayer(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = Blip2QFormerAttention(config)
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
if layer_idx % config.cross_attention_frequency == 0:
|
||||
self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True)
|
||||
self.has_cross_attention = True
|
||||
else:
|
||||
self.has_cross_attention = False
|
||||
|
||||
self.intermediate = Blip2QFormerIntermediate(config)
|
||||
self.intermediate_query = Blip2QFormerIntermediate(config)
|
||||
self.output_query = Blip2QFormerOutput(config)
|
||||
self.output = Blip2QFormerOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
query_length=0,
|
||||
):
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
|
||||
if query_length > 0:
|
||||
query_attention_output = attention_output[:, :query_length, :]
|
||||
|
||||
if self.has_cross_attention:
|
||||
if encoder_hidden_states is None:
|
||||
raise ValueError("encoder_hidden_states must be given for cross-attention layers")
|
||||
cross_attention_outputs = self.crossattention(
|
||||
query_attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
query_attention_output = cross_attention_outputs[0]
|
||||
# add cross attentions if we output attention weights
|
||||
outputs = outputs + cross_attention_outputs[1:-1]
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk_query,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
query_attention_output,
|
||||
)
|
||||
|
||||
if attention_output.shape[1] > query_length:
|
||||
layer_output_text = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output[:, query_length:, :],
|
||||
)
|
||||
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
||||
else:
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output,
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
def feed_forward_chunk_query(self, attention_output):
|
||||
intermediate_output = self.intermediate_query(attention_output)
|
||||
layer_output = self.output_query(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder
|
||||
class ProjLayer(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
|
||||
super().__init__()
|
||||
|
||||
# Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm
|
||||
self.dense1 = nn.Linear(in_dim, hidden_dim)
|
||||
self.act_fn = QuickGELU()
|
||||
self.dense2 = nn.Linear(hidden_dim, out_dim)
|
||||
self.dropout = nn.Dropout(drop_p)
|
||||
|
||||
self.LayerNorm = nn.LayerNorm(out_dim, eps=eps)
|
||||
|
||||
def forward(self, x):
|
||||
x_in = x
|
||||
|
||||
x = self.LayerNorm(x)
|
||||
x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2
|
||||
class Blip2VisionModel(Blip2PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
config_class = Blip2VisionConfig
|
||||
|
||||
def __init__(self, config: Blip2VisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
self.embeddings = Blip2VisionEmbeddings(config)
|
||||
self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.encoder = Blip2Encoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layernorm(hidden_states)
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
|
||||
# Qformer model, used to get multimodal embeddings from the text and image inputs
|
||||
class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
"""
|
||||
Querying Transformer (Q-Former), used in BLIP-2.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.embeddings = Blip2TextEmbeddings(config.qformer_config)
|
||||
self.visual_encoder = Blip2VisionModel(config.vision_config)
|
||||
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
||||
if not hasattr(config, "tokenizer") or config.tokenizer is None:
|
||||
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
|
||||
else:
|
||||
self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right")
|
||||
self.tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||
self.proj_layer = ProjLayer(
|
||||
in_dim=config.qformer_config.hidden_size,
|
||||
out_dim=config.qformer_config.hidden_size,
|
||||
hidden_dim=config.qformer_config.hidden_size * 4,
|
||||
drop_p=0.1,
|
||||
eps=1e-12,
|
||||
)
|
||||
|
||||
self.encoder = Blip2QFormerEncoder(config.qformer_config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def get_extended_attention_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_shape: Tuple[int],
|
||||
device: torch.device,
|
||||
has_query: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||
|
||||
Arguments:
|
||||
attention_mask (`torch.Tensor`):
|
||||
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||
input_shape (`Tuple[int]`):
|
||||
The shape of the input to the model.
|
||||
device (`torch.device`):
|
||||
The device of the input to the model.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
|
||||
"""
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
if attention_mask.dim() == 3:
|
||||
extended_attention_mask = attention_mask[:, None, :, :]
|
||||
elif attention_mask.dim() == 2:
|
||||
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||
# - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||
input_shape, attention_mask.shape
|
||||
)
|
||||
)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
return extended_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_input=None,
|
||||
image_input=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
|
||||
shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
|
||||
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
|
||||
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
|
||||
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
|
||||
`(batch_size, sequence_length)`.
|
||||
use_cache (`bool`, `optional`):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
|
||||
text = self.tokenizer(text_input, return_tensors="pt", padding=True)
|
||||
text = text.to(self.device)
|
||||
input_ids = text.input_ids
|
||||
batch_size = input_ids.shape[0]
|
||||
query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device)
|
||||
attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = (
|
||||
past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
|
||||
)
|
||||
|
||||
query_length = self.query_tokens.shape[1]
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
query_embeds=self.query_tokens,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# embedding_output = self.layernorm(query_embeds)
|
||||
# embedding_output = self.dropout(embedding_output)
|
||||
|
||||
input_shape = embedding_output.size()[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = embedding_output.device
|
||||
|
||||
image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state
|
||||
# image_embeds_frozen = torch.ones_like(image_embeds_frozen)
|
||||
encoder_hidden_states = image_embeds_frozen
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, list):
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
||||
else:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
|
||||
if isinstance(encoder_attention_mask, list):
|
||||
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
||||
elif encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
query_length=query_length,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
|
||||
if not return_dict:
|
||||
return self.proj_layer(sequence_output[:, :query_length, :])
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
@@ -1,212 +0,0 @@
|
||||
# Copyright 2023 Salesforce.com, inc.
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import CLIPPreTrainedModel
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
||||
from transformers.models.clip.configuration_clip import CLIPTextConfig
|
||||
from transformers.models.clip.modeling_clip import (
|
||||
CLIPEncoder,
|
||||
_expand_mask,
|
||||
)
|
||||
|
||||
|
||||
# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
|
||||
# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer
|
||||
# They pass through the clip model, along with the text embeddings, and interact with them using self attention
|
||||
class ContextCLIPTextModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPTextConfig
|
||||
|
||||
_no_split_modules = ["CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__(config)
|
||||
self.text_model = ContextCLIPTextTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_embeddings: torch.Tensor = None,
|
||||
ctx_begin_pos: list = None,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
return self.text_model(
|
||||
ctx_embeddings=ctx_embeddings,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
class ContextCLIPTextTransformer(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
self.embeddings = ContextCLIPTextEmbeddings(config)
|
||||
self.encoder = CLIPEncoder(config)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_embeddings: torch.Tensor,
|
||||
ctx_begin_pos: list,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify either input_ids")
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
ctx_embeddings=ctx_embeddings,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
)
|
||||
|
||||
bsz, seq_len = input_shape
|
||||
if ctx_embeddings is not None:
|
||||
seq_len += ctx_embeddings.size(1)
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
|
||||
hidden_states.device
|
||||
)
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
|
||||
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0], device=input_ids.device),
|
||||
input_ids.to(torch.int).argmax(dim=-1),
|
||||
]
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
|
||||
mask.fill_(torch.tensor(torch.finfo(dtype).min))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
mask = mask.unsqueeze(1) # expand mask
|
||||
return mask
|
||||
|
||||
|
||||
class ContextCLIPTextEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_embeddings: torch.Tensor,
|
||||
ctx_begin_pos: list,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if ctx_embeddings is None:
|
||||
ctx_len = 0
|
||||
else:
|
||||
ctx_len = ctx_embeddings.shape[1]
|
||||
|
||||
seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
# for each input embeddings, add the ctx embeddings at the correct position
|
||||
input_embeds_ctx = []
|
||||
bsz = inputs_embeds.shape[0]
|
||||
|
||||
if ctx_embeddings is not None:
|
||||
for i in range(bsz):
|
||||
cbp = ctx_begin_pos[i]
|
||||
|
||||
prefix = inputs_embeds[i, :cbp]
|
||||
# remove the special token embedding
|
||||
suffix = inputs_embeds[i, cbp:]
|
||||
|
||||
input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0))
|
||||
|
||||
inputs_embeds = torch.stack(input_embeds_ctx, dim=0)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
return embeddings
|
||||
@@ -1,339 +0,0 @@
|
||||
# Copyright 2023 Salesforce.com, inc.
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import PNDMScheduler
|
||||
from ...utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from .blip_image_processing import BlipImageProcessor
|
||||
from .modeling_blip2 import Blip2QFormerModel
|
||||
from .modeling_ctx_clip import ContextCLIPTextModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers.pipelines import BlipDiffusionPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import torch
|
||||
|
||||
>>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained(
|
||||
... "Salesforce/blipdiffusion", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
|
||||
>>> cond_subject = "dog"
|
||||
>>> tgt_subject = "dog"
|
||||
>>> text_prompt_input = "swimming underwater"
|
||||
|
||||
>>> cond_image = load_image(
|
||||
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
|
||||
... )
|
||||
>>> guidance_scale = 7.5
|
||||
>>> num_inference_steps = 25
|
||||
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
|
||||
|
||||
|
||||
>>> output = blip_diffusion_pipe(
|
||||
... text_prompt_input,
|
||||
... cond_image,
|
||||
... cond_subject,
|
||||
... tgt_subject,
|
||||
... guidance_scale=guidance_scale,
|
||||
... num_inference_steps=num_inference_steps,
|
||||
... neg_prompt=negative_prompt,
|
||||
... height=512,
|
||||
... width=512,
|
||||
... ).images
|
||||
>>> output[0].save("image.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class BlipDiffusionPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer ([`CLIPTokenizer`]):
|
||||
Tokenizer for the text encoder
|
||||
text_encoder ([`ContextCLIPTextModel`]):
|
||||
Text encoder to encode the text prompt
|
||||
vae ([`AutoencoderKL`]):
|
||||
VAE model to map the latents to the image
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
scheduler ([`PNDMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
qformer ([`Blip2QFormerModel`]):
|
||||
QFormer model to get multi-modal embeddings from the text and image.
|
||||
image_processor ([`BlipImageProcessor`]):
|
||||
Image Processor to preprocess and postprocess the image.
|
||||
ctx_begin_pos (int, `optional`, defaults to 2):
|
||||
Position of the context token in the text encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: ContextCLIPTextModel,
|
||||
vae: AutoencoderKL,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: PNDMScheduler,
|
||||
qformer: Blip2QFormerModel,
|
||||
image_processor: BlipImageProcessor,
|
||||
ctx_begin_pos: int = 2,
|
||||
mean: List[float] = None,
|
||||
std: List[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
qformer=qformer,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
|
||||
|
||||
def get_query_embeddings(self, input_image, src_subject):
|
||||
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
|
||||
|
||||
# from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
|
||||
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
|
||||
rv = []
|
||||
for prompt, tgt_subject in zip(prompts, tgt_subjects):
|
||||
prompt = f"a {tgt_subject} {prompt.strip()}"
|
||||
# a trick to amplify the prompt
|
||||
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
|
||||
|
||||
return rv
|
||||
|
||||
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels, height, width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def encode_prompt(self, query_embeds, prompt):
|
||||
# embeddings for prompt, with query_embeds as context
|
||||
max_len = self.text_encoder.text_model.config.max_position_embeddings
|
||||
max_len -= self.qformer.config.num_query_tokens
|
||||
|
||||
tokenized_prompt = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_len,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
batch_size = query_embeds.shape[0]
|
||||
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
|
||||
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=tokenized_prompt.input_ids,
|
||||
ctx_embeddings=query_embeds,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
)[0]
|
||||
|
||||
return text_embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: List[str],
|
||||
reference_image: PIL.Image.Image,
|
||||
source_subject_category: List[str],
|
||||
target_subject_category: List[str],
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
neg_prompt: Optional[str] = "",
|
||||
prompt_strength: float = 1.0,
|
||||
prompt_reps: int = 20,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
reference_image (`PIL.Image.Image`):
|
||||
The reference image to condition the generation on.
|
||||
source_subject_category (`List[str]`):
|
||||
The source subject category.
|
||||
target_subject_category (`List[str]`):
|
||||
The target subject category.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by random sampling.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
neg_prompt (`str`, *optional*, defaults to ""):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_strength (`float`, *optional*, defaults to 1.0):
|
||||
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
|
||||
to amplify the prompt.
|
||||
prompt_reps (`int`, *optional*, defaults to 20):
|
||||
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
reference_image = self.image_processor.preprocess(
|
||||
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
|
||||
)["pixel_values"]
|
||||
reference_image = reference_image.to(self.device)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
if isinstance(source_subject_category, str):
|
||||
source_subject_category = [source_subject_category]
|
||||
if isinstance(target_subject_category, str):
|
||||
target_subject_category = [target_subject_category]
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = self._build_prompt(
|
||||
prompts=prompt,
|
||||
tgt_subjects=target_subject_category,
|
||||
prompt_strength=prompt_strength,
|
||||
prompt_reps=prompt_reps,
|
||||
)
|
||||
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
|
||||
text_embeddings = self.encode_prompt(query_embeds, prompt)
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
if do_classifier_free_guidance:
|
||||
max_length = self.text_encoder.text_model.config.max_position_embeddings
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
[neg_prompt] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.to(self.device),
|
||||
ctx_embeddings=None,
|
||||
)[0]
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
num_channels=self.unet.config.in_channels,
|
||||
height=height // scale_down_factor,
|
||||
width=width // scale_down_factor,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
dtype=self.unet.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
# set timesteps
|
||||
extra_set_kwargs = {}
|
||||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
down_block_additional_residuals=None,
|
||||
mid_block_additional_residual=None,
|
||||
)["sample"]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
)["prev_sample"]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -1,79 +1,77 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
||||
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
||||
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
@@ -291,7 +291,7 @@ class StableDiffusionControlNetPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -1,405 +0,0 @@
|
||||
# Copyright 2023 Salesforce.com, inc.
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import PNDMScheduler
|
||||
from ...utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
|
||||
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
||||
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from controlnet_aux import CannyDetector
|
||||
>>> import torch
|
||||
|
||||
>>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
|
||||
... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> style_subject = "flower"
|
||||
>>> tgt_subject = "teapot"
|
||||
>>> text_prompt = "on a marble table"
|
||||
|
||||
>>> cldm_cond_image = load_image(
|
||||
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
|
||||
... ).resize((512, 512))
|
||||
>>> canny = CannyDetector()
|
||||
>>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
|
||||
>>> style_image = load_image(
|
||||
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
|
||||
... )
|
||||
>>> guidance_scale = 7.5
|
||||
>>> num_inference_steps = 50
|
||||
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
|
||||
|
||||
|
||||
>>> output = blip_diffusion_pipe(
|
||||
... text_prompt,
|
||||
... style_image,
|
||||
... cldm_cond_image,
|
||||
... style_subject,
|
||||
... tgt_subject,
|
||||
... guidance_scale=guidance_scale,
|
||||
... num_inference_steps=num_inference_steps,
|
||||
... neg_prompt=negative_prompt,
|
||||
... height=512,
|
||||
... width=512,
|
||||
... ).images
|
||||
>>> output[0].save("image.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class BlipDiffusionControlNetPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer ([`CLIPTokenizer`]):
|
||||
Tokenizer for the text encoder
|
||||
text_encoder ([`ContextCLIPTextModel`]):
|
||||
Text encoder to encode the text prompt
|
||||
vae ([`AutoencoderKL`]):
|
||||
VAE model to map the latents to the image
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
scheduler ([`PNDMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
qformer ([`Blip2QFormerModel`]):
|
||||
QFormer model to get multi-modal embeddings from the text and image.
|
||||
controlnet ([`ControlNetModel`]):
|
||||
ControlNet model to get the conditioning image embedding.
|
||||
image_processor ([`BlipImageProcessor`]):
|
||||
Image Processor to preprocess and postprocess the image.
|
||||
ctx_begin_pos (int, `optional`, defaults to 2):
|
||||
Position of the context token in the text encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: ContextCLIPTextModel,
|
||||
vae: AutoencoderKL,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: PNDMScheduler,
|
||||
qformer: Blip2QFormerModel,
|
||||
controlnet: ControlNetModel,
|
||||
image_processor: BlipImageProcessor,
|
||||
ctx_begin_pos: int = 2,
|
||||
mean: List[float] = None,
|
||||
std: List[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
qformer=qformer,
|
||||
controlnet=controlnet,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
|
||||
|
||||
def get_query_embeddings(self, input_image, src_subject):
|
||||
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
|
||||
|
||||
# from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
|
||||
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
|
||||
rv = []
|
||||
for prompt, tgt_subject in zip(prompts, tgt_subjects):
|
||||
prompt = f"a {tgt_subject} {prompt.strip()}"
|
||||
# a trick to amplify the prompt
|
||||
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
|
||||
|
||||
return rv
|
||||
|
||||
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels, height, width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def encode_prompt(self, query_embeds, prompt):
|
||||
# embeddings for prompt, with query_embeds as context
|
||||
max_len = self.text_encoder.text_model.config.max_position_embeddings
|
||||
max_len -= self.qformer.config.num_query_tokens
|
||||
|
||||
tokenized_prompt = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_len,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
batch_size = query_embeds.shape[0]
|
||||
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
|
||||
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=tokenized_prompt.input_ids,
|
||||
ctx_embeddings=query_embeds,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
)[0]
|
||||
|
||||
return text_embeddings
|
||||
|
||||
# Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
|
||||
def prepare_control_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
):
|
||||
image = self.image_processor.preprocess(
|
||||
image,
|
||||
size={"width": width, "height": height},
|
||||
do_rescale=True,
|
||||
do_center_crop=False,
|
||||
do_normalize=False,
|
||||
return_tensors="pt",
|
||||
)["pixel_values"].to(self.device)
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: List[str],
|
||||
reference_image: PIL.Image.Image,
|
||||
condtioning_image: PIL.Image.Image,
|
||||
source_subject_category: List[str],
|
||||
target_subject_category: List[str],
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
neg_prompt: Optional[str] = "",
|
||||
prompt_strength: float = 1.0,
|
||||
prompt_reps: int = 20,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
reference_image (`PIL.Image.Image`):
|
||||
The reference image to condition the generation on.
|
||||
condtioning_image (`PIL.Image.Image`):
|
||||
The conditioning canny edge image to condition the generation on.
|
||||
source_subject_category (`List[str]`):
|
||||
The source subject category.
|
||||
target_subject_category (`List[str]`):
|
||||
The target subject category.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by random sampling.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width of the generated image.
|
||||
seed (`int`, *optional*, defaults to 42):
|
||||
The seed to use for random generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
neg_prompt (`str`, *optional*, defaults to ""):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_strength (`float`, *optional*, defaults to 1.0):
|
||||
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
|
||||
to amplify the prompt.
|
||||
prompt_reps (`int`, *optional*, defaults to 20):
|
||||
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
reference_image = self.image_processor.preprocess(
|
||||
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
|
||||
)["pixel_values"]
|
||||
reference_image = reference_image.to(self.device)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
if isinstance(source_subject_category, str):
|
||||
source_subject_category = [source_subject_category]
|
||||
if isinstance(target_subject_category, str):
|
||||
target_subject_category = [target_subject_category]
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = self._build_prompt(
|
||||
prompts=prompt,
|
||||
tgt_subjects=target_subject_category,
|
||||
prompt_strength=prompt_strength,
|
||||
prompt_reps=prompt_reps,
|
||||
)
|
||||
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
|
||||
text_embeddings = self.encode_prompt(query_embeds, prompt)
|
||||
# 3. unconditional embedding
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
if do_classifier_free_guidance:
|
||||
max_length = self.text_encoder.text_model.config.max_position_embeddings
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
[neg_prompt] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.to(self.device),
|
||||
ctx_embeddings=None,
|
||||
)[0]
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
num_channels=self.unet.config.in_channels,
|
||||
height=height // scale_down_factor,
|
||||
width=width // scale_down_factor,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
dtype=self.unet.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
# set timesteps
|
||||
extra_set_kwargs = {}
|
||||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
||||
|
||||
cond_image = self.prepare_control_image(
|
||||
image=condtioning_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=1,
|
||||
device=self.device,
|
||||
dtype=self.controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
controlnet_cond=cond_image,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
)["sample"]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
)["prev_sample"]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -315,7 +315,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -442,7 +442,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -315,8 +315,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
|
||||
@@ -288,8 +288,8 @@ class StableDiffusionXLControlNetPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
|
||||
@@ -326,8 +326,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
|
||||
@@ -394,29 +394,10 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# extract them here
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
# define init kwargs
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
# remove `null` components
|
||||
def load_module(name, value):
|
||||
if value[0] is None:
|
||||
return False
|
||||
if name in passed_class_obj and passed_class_obj[name] is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
# Throw nice warnings / errors for fast accelerate loading
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
||||
)
|
||||
init_kwargs = {}
|
||||
|
||||
# inference_params
|
||||
params = {}
|
||||
|
||||
@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -301,7 +301,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
+1
-1
@@ -332,7 +332,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -481,7 +481,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -278,7 +278,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
+1
-1
@@ -309,7 +309,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -302,7 +302,7 @@ class StableDiffusionImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -375,7 +375,7 @@ class StableDiffusionInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
+1
-1
@@ -297,7 +297,7 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -211,7 +211,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -609,9 +609,6 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
|
||||
sampler_kwargs["noise_sampler"] = noise_sampler
|
||||
|
||||
if "generator" in inspect.signature(self.sampler).parameters:
|
||||
sampler_kwargs["generator"] = generator
|
||||
|
||||
latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
|
||||
|
||||
if not output_type == "latent":
|
||||
|
||||
@@ -272,7 +272,7 @@ class StableDiffusionLDM3DPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -244,7 +244,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -221,7 +221,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -256,7 +256,7 @@ class StableDiffusionParadigmsPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -446,7 +446,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -244,7 +244,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -240,7 +240,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -346,7 +346,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -296,7 +296,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -4,18 +4,14 @@ from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["StableDiffusionXLPipelineOutput"]}
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
_import_structure["pipeline_output"].extend(["FlaxStableDiffusionXLPipelineOutput"])
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -29,12 +25,6 @@ else:
|
||||
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
||||
|
||||
_additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState})
|
||||
_import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
@@ -48,17 +38,6 @@ if TYPE_CHECKING:
|
||||
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
|
||||
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_objects import *
|
||||
else:
|
||||
from .pipeline_flax_stable_diffusion_xl import (
|
||||
FlaxStableDiffusionXLPipeline,
|
||||
)
|
||||
from .pipeline_output import FlaxStableDiffusionXLPipelineOutput
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -71,5 +50,3 @@ else:
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
@@ -1,306 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from transformers import CLIPTokenizer, FlaxCLIPTextModel
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDPMSolverMultistepScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
)
|
||||
from ..pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
from . import FlaxStableDiffusionXLPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
|
||||
DEBUG = False
|
||||
|
||||
|
||||
class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: FlaxCLIPTextModel,
|
||||
text_encoder_2: FlaxCLIPTextModel,
|
||||
vae: FlaxAutoencoderKL,
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: FlaxUNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
|
||||
],
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
def prepare_inputs(self, prompt: Union[str, List[str]]):
|
||||
if not isinstance(prompt, (str, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
# Assume we have the two encoders
|
||||
inputs = []
|
||||
for tokenizer in [self.tokenizer, self.tokenizer_2]:
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
inputs.append(text_inputs.input_ids)
|
||||
inputs = jnp.stack(inputs, axis=1)
|
||||
return inputs
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt_ids: jax.Array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: Union[float, jax.Array] = 7.5,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
latents: jnp.array = None,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
return_dict: bool = True,
|
||||
output_type: str = None,
|
||||
jit: bool = False,
|
||||
):
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if isinstance(guidance_scale, float) and jit:
|
||||
# Convert to a tensor so each device gets a copy.
|
||||
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
|
||||
guidance_scale = guidance_scale[:, None]
|
||||
|
||||
return_latents = output_type == "latent"
|
||||
|
||||
if jit:
|
||||
images = _p_generate(
|
||||
self,
|
||||
prompt_ids,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
return_latents,
|
||||
)
|
||||
else:
|
||||
images = self._generate(
|
||||
prompt_ids,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
return_latents,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (images,)
|
||||
|
||||
return FlaxStableDiffusionXLPipelineOutput(images=images)
|
||||
|
||||
def get_embeddings(self, prompt_ids: jnp.array, params):
|
||||
# We assume we have the two encoders
|
||||
|
||||
# bs, encoder_input, seq_length
|
||||
te_1_inputs = prompt_ids[:, 0, :]
|
||||
te_2_inputs = prompt_ids[:, 1, :]
|
||||
|
||||
prompt_embeds = self.text_encoder(te_1_inputs, params=params["text_encoder"], output_hidden_states=True)
|
||||
prompt_embeds = prompt_embeds["hidden_states"][-2]
|
||||
prompt_embeds_2_out = self.text_encoder_2(
|
||||
te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True
|
||||
)
|
||||
prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2]
|
||||
text_embeds = prompt_embeds_2_out["text_embeds"]
|
||||
prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
|
||||
return prompt_embeds, text_embeds
|
||||
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, bs, dtype):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = jnp.array([add_time_ids] * bs, dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int,
|
||||
height: int,
|
||||
width: int,
|
||||
guidance_scale: float,
|
||||
latents: Optional[jnp.array] = None,
|
||||
neg_prompt_ids: Optional[jnp.array] = None,
|
||||
return_latents=False,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
# Encode input prompt
|
||||
prompt_embeds, pooled_embeds = self.get_embeddings(prompt_ids, params)
|
||||
|
||||
# Get unconditional embeddings
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
if neg_prompt_ids is None:
|
||||
neg_prompt_ids = self.prepare_inputs([""] * batch_size)
|
||||
|
||||
neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
(height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0) # (2, 77, 2048)
|
||||
add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
|
||||
add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
|
||||
|
||||
# Ensure model output will be `float32` before going into the scheduler
|
||||
guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
|
||||
|
||||
# Create random latents
|
||||
latents_shape = (
|
||||
batch_size,
|
||||
self.unet.config.in_channels,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if latents is None:
|
||||
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * params["scheduler"].init_noise_sigma
|
||||
|
||||
# Prepare scheduler state
|
||||
scheduler_state = self.scheduler.set_timesteps(
|
||||
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
|
||||
)
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
# Denoising loop
|
||||
def loop_body(step, args):
|
||||
latents, scheduler_state = args
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
latents_input = jnp.concatenate([latents] * 2)
|
||||
|
||||
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
|
||||
timestep = jnp.broadcast_to(t, latents_input.shape[0])
|
||||
|
||||
latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet.apply(
|
||||
{"params": params["unet"]},
|
||||
jnp.array(latents_input),
|
||||
jnp.array(timestep, dtype=jnp.int32),
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
|
||||
return latents, scheduler_state
|
||||
|
||||
if DEBUG:
|
||||
# run with python for loop
|
||||
for i in range(num_inference_steps):
|
||||
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
|
||||
else:
|
||||
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
|
||||
|
||||
if return_latents:
|
||||
return latents
|
||||
|
||||
# Decode latents
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
|
||||
|
||||
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
|
||||
return image
|
||||
|
||||
|
||||
# Static argnums are pipe, num_inference_steps, height, width, return_latents. A change would trigger recompilation.
|
||||
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
|
||||
@partial(
|
||||
jax.pmap,
|
||||
in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0, None),
|
||||
static_broadcasted_argnums=(0, 4, 5, 6, 10),
|
||||
)
|
||||
def _p_generate(
|
||||
pipe,
|
||||
prompt_ids,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
return_latents,
|
||||
):
|
||||
return pipe._generate(
|
||||
prompt_ids,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
return_latents,
|
||||
)
|
||||
@@ -4,11 +4,7 @@ from typing import List, Union
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
is_flax_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -23,19 +19,3 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
import flax
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxStableDiffusionXLPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Flax Stable Diffusion XL pipelines.
|
||||
|
||||
Args:
|
||||
images (`np.ndarray`)
|
||||
Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: np.ndarray
|
||||
|
||||
@@ -264,8 +264,8 @@ class StableDiffusionXLPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
|
||||
@@ -271,8 +271,8 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -722,16 +722,16 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
|
||||
it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
|
||||
strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
|
||||
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
|
||||
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
|
||||
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
|
||||
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
|
||||
denoising_end (`float`, *optional*):
|
||||
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
||||
still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
|
||||
denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
|
||||
final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
|
||||
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image
|
||||
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
|
||||
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
||||
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
|
||||
@@ -420,8 +420,8 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
|
||||
+2
-2
@@ -272,8 +272,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -296,7 +296,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -288,8 +288,8 @@ class StableDiffusionXLAdapterPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -787,16 +787,8 @@ class StableDiffusionXLAdapterPipeline(
|
||||
height, width = self._default_height_width(height, width, image)
|
||||
device = self._execution_device
|
||||
|
||||
if isinstance(self.adapter, MultiAdapter):
|
||||
adapter_input = []
|
||||
adapter_input = _preprocess_adapter_image(image, height, width).to(device)
|
||||
|
||||
for one_image in image:
|
||||
one_image = _preprocess_adapter_image(one_image, height, width)
|
||||
one_image = one_image.to(device=device, dtype=self.adapter.dtype)
|
||||
adapter_input.append(one_image)
|
||||
else:
|
||||
adapter_input = _preprocess_adapter_image(image, height, width)
|
||||
adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
@@ -873,14 +865,10 @@ class StableDiffusionXLAdapterPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare added time ids & embeddings & adapter features
|
||||
if isinstance(self.adapter, MultiAdapter):
|
||||
adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v
|
||||
else:
|
||||
adapter_state = self.adapter(adapter_input)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v * adapter_conditioning_scale
|
||||
adapter_input = adapter_input.type(latents.dtype)
|
||||
adapter_state = self.adapter(adapter_input)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v * adapter_conditioning_scale
|
||||
if num_images_per_prompt > 1:
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
|
||||
|
||||
@@ -228,7 +228,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
+1
-1
@@ -290,7 +290,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -1094,6 +1094,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
@@ -76,7 +76,6 @@ else:
|
||||
_import_structure["scheduling_ddim_flax"] = ["FlaxDDIMScheduler"]
|
||||
_import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
|
||||
_import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
|
||||
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
|
||||
_import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
|
||||
_import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
|
||||
_import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
|
||||
|
||||
@@ -22,7 +22,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
@@ -187,14 +186,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -234,17 +225,18 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
||||
|
||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
||||
# duplicates in timesteps.
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
timesteps = timesteps[np.sort(unique_indices)]
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
|
||||
@@ -253,9 +245,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -291,57 +280,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DEIS algorithm needs.
|
||||
@@ -358,26 +298,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -389,6 +316,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
if self.config.algorithm_type == "deis":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
return (sample - alpha_t * x0_pred) / sigma_t
|
||||
else:
|
||||
raise NotImplementedError("only support log-rho multistep deis now")
|
||||
@@ -396,9 +324,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def deis_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DEIS (equivalent to DDIM).
|
||||
@@ -417,33 +345,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "deis":
|
||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
@@ -454,9 +358,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_deis_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order multistep DEIS.
|
||||
@@ -464,6 +368,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -471,38 +379,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1]
|
||||
sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1]
|
||||
|
||||
rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
|
||||
|
||||
@@ -523,9 +403,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_deis_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order multistep DEIS.
|
||||
@@ -533,6 +413,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
|
||||
@@ -540,47 +424,15 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2]
|
||||
sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2]
|
||||
rho_t, rho_s0, rho_s1, rho_s2 = (
|
||||
sigma_t / alpha_t,
|
||||
sigma_s0 / alpha_s0,
|
||||
sigma_s1 / alpha_s1,
|
||||
sigma_s2 / alpha_s2,
|
||||
simga_s2 / alpha_s2,
|
||||
)
|
||||
|
||||
if self.config.algorithm_type == "deis":
|
||||
@@ -608,25 +460,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError("only support log-rho multistep deis now")
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -659,34 +492,42 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
lower_order_final = (
|
||||
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
lower_order_second = (
|
||||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
|
||||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.deis_first_order_update(model_output, sample=sample)
|
||||
prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample)
|
||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample)
|
||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_deis_second_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
else:
|
||||
prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample)
|
||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_deis_third_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -707,31 +548,28 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -21,7 +21,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
@@ -204,14 +203,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -251,20 +242,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
||||
|
||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
||||
# duplicates in timesteps.
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
timesteps = timesteps[np.sort(unique_indices)]
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
|
||||
@@ -273,9 +264,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -335,12 +323,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -356,11 +338,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
||||
@@ -377,6 +355,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -384,18 +364,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
||||
@@ -403,14 +371,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned", "learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -432,12 +398,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
epsilon = model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -446,8 +410,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
||||
@@ -457,10 +420,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DPMSolver (equivalent to DDIM).
|
||||
@@ -468,6 +431,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -475,33 +442,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
@@ -526,10 +469,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order multistep DPMSolver.
|
||||
@@ -537,6 +480,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -544,43 +491,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
||||
@@ -649,9 +564,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order multistep DPMSolver.
|
||||
@@ -659,6 +574,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
|
||||
@@ -666,47 +585,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
||||
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
)
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m0
|
||||
@@ -731,25 +619,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -785,17 +654,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
lower_order_final = (
|
||||
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
lower_order_second = (
|
||||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
@@ -808,18 +682,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise = None
|
||||
|
||||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
|
||||
prev_sample = self.dpm_solver_first_order_update(
|
||||
model_output, timestep, prev_timestep, sample, noise=noise
|
||||
)
|
||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
|
||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
|
||||
)
|
||||
else:
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
|
||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -840,30 +719,28 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -21,7 +21,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
@@ -204,16 +203,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -253,19 +244,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
timesteps = timesteps.copy().astype(np.int64)
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_max = (
|
||||
(1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
|
||||
) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
@@ -274,7 +257,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
timesteps = timesteps[np.sort(unique_indices)]
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
|
||||
@@ -283,9 +266,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -345,13 +325,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -368,11 +341,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
||||
@@ -389,6 +358,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -396,18 +367,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
||||
@@ -415,14 +374,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned", "learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -444,12 +401,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
epsilon = model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -458,22 +413,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
||||
|
||||
return epsilon
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DPMSolver (equivalent to DDIM).
|
||||
@@ -481,6 +434,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -488,62 +445,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
elif self.config.algorithm_type == "sde-dpmsolver++":
|
||||
assert noise is not None
|
||||
x_t = (
|
||||
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config.algorithm_type == "sde-dpmsolver":
|
||||
assert noise is not None
|
||||
x_t = (
|
||||
(alpha_t / alpha_s) * sample
|
||||
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
||||
elif "sde" in self.config.algorithm_type:
|
||||
raise NotImplementedError(
|
||||
f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order multistep DPMSolver.
|
||||
@@ -551,6 +473,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -558,43 +484,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
||||
@@ -626,47 +520,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
)
|
||||
elif self.config.algorithm_type == "sde-dpmsolver++":
|
||||
assert noise is not None
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
||||
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
||||
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config.algorithm_type == "sde-dpmsolver":
|
||||
assert noise is not None
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * (torch.exp(h) - 1.0)) * D1
|
||||
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
||||
)
|
||||
elif "sde" in self.config.algorithm_type:
|
||||
raise NotImplementedError(
|
||||
f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
|
||||
def multistep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order multistep DPMSolver.
|
||||
@@ -674,6 +540,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
|
||||
@@ -681,47 +551,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
||||
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
)
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m0
|
||||
@@ -746,27 +585,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -786,8 +604,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
||||
|
||||
@@ -802,17 +618,24 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = (
|
||||
self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
)
|
||||
lower_order_final = (
|
||||
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
lower_order_second = (
|
||||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
@@ -825,18 +648,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise = None
|
||||
|
||||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
|
||||
prev_sample = self.dpm_solver_first_order_update(
|
||||
model_output, timestep, prev_timestep, sample, noise=noise
|
||||
)
|
||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
|
||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
|
||||
)
|
||||
else:
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
|
||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -858,31 +686,28 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -21,7 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils import logging
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
@@ -197,7 +197,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.sample = None
|
||||
self.order_list = self.get_order_list(num_train_timesteps)
|
||||
self._step_index = None
|
||||
|
||||
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
||||
"""
|
||||
@@ -233,13 +232,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
orders = [1] * steps
|
||||
return orders
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -264,18 +256,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
self.model_outputs = [None] * self.config.solver_order
|
||||
self.sample = None
|
||||
|
||||
@@ -287,9 +274,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.order_list = self.get_order_list(num_inference_steps)
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -349,13 +333,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -371,11 +348,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
||||
@@ -392,6 +365,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -399,32 +374,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
if self.config.prediction_type == "epsilon":
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -444,13 +405,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output = model_output[:, :3]
|
||||
return model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
@@ -462,9 +421,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DPMSolver (equivalent to DDIM).
|
||||
@@ -483,31 +442,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
@@ -518,9 +455,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def singlestep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
|
||||
@@ -540,42 +477,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1]
|
||||
sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1]
|
||||
h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m1, (1.0 / r0) * (m0 - m1)
|
||||
@@ -612,9 +518,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def singlestep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
|
||||
@@ -634,47 +540,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
||||
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
)
|
||||
alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2]
|
||||
sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2]
|
||||
h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m2
|
||||
@@ -716,10 +591,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def singlestep_dpm_solver_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
order: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the singlestep DPMSolver.
|
||||
@@ -740,60 +615,19 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if order is None:
|
||||
if len(args) > 3:
|
||||
order = args[3]
|
||||
else:
|
||||
raise ValueError(" missing `order` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if order == 1:
|
||||
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
|
||||
return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample)
|
||||
elif order == 2:
|
||||
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
|
||||
return self.singlestep_dpm_solver_second_order_update(
|
||||
model_output_list, timestep_list, prev_timestep, sample
|
||||
)
|
||||
elif order == 3:
|
||||
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
|
||||
return self.singlestep_dpm_solver_third_order_update(
|
||||
model_output_list, timestep_list, prev_timestep, sample
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Order must be 1, 2, 3, got {order}")
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -826,15 +660,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
|
||||
order = self.order_list[self.step_index]
|
||||
order = self.order_list[step_index]
|
||||
|
||||
# For img2img denoising might start with order>1 which is not possible
|
||||
# In this case make sure that the first two steps are both order=1
|
||||
@@ -845,10 +685,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if order == 1:
|
||||
self.sample = sample
|
||||
|
||||
prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep]
|
||||
prev_sample = self.singlestep_dpm_solver_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, self.sample, order
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
@@ -870,31 +710,28 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -1,265 +0,0 @@
|
||||
# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import flax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
FlaxSchedulerMixin,
|
||||
FlaxSchedulerOutput,
|
||||
broadcast_to_shape_from_left,
|
||||
)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class EulerDiscreteSchedulerState:
|
||||
common: CommonSchedulerState
|
||||
|
||||
# setable values
|
||||
init_noise_sigma: jnp.ndarray
|
||||
timesteps: jnp.ndarray
|
||||
sigmas: jnp.ndarray
|
||||
num_inference_steps: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
|
||||
):
|
||||
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxEulerDiscreteSchedulerOutput(FlaxSchedulerOutput):
|
||||
state: EulerDiscreteSchedulerState
|
||||
|
||||
|
||||
class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original
|
||||
k-diffusion implementation by Katherine Crowson:
|
||||
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51
|
||||
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
beta_start (`float`): the starting `beta` value of inference.
|
||||
beta_end (`float`): the final `beta` value.
|
||||
beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`jnp.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
prediction_type (`str`, default `epsilon`, optional):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
||||
https://imagen.research.google/video/paper.pdf)
|
||||
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
|
||||
the `dtype` used for params and computation.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
|
||||
|
||||
dtype: jnp.dtype
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[jnp.ndarray] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState:
|
||||
if common is None:
|
||||
common = CommonSchedulerState.create(self)
|
||||
|
||||
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
|
||||
sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5
|
||||
sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
|
||||
sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
||||
init_noise_sigma = sigmas.max()
|
||||
else:
|
||||
init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5
|
||||
|
||||
return EulerDiscreteSchedulerState.create(
|
||||
common=common,
|
||||
init_noise_sigma=init_noise_sigma,
|
||||
timesteps=timesteps,
|
||||
sigmas=sigmas,
|
||||
)
|
||||
|
||||
def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
|
||||
"""
|
||||
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
||||
|
||||
Args:
|
||||
state (`EulerDiscreteSchedulerState`):
|
||||
the `FlaxEulerDiscreteScheduler` state data class instance.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
timestep (`int`):
|
||||
current discrete timestep in the diffusion chain.
|
||||
|
||||
Returns:
|
||||
`jnp.ndarray`: scaled input sample
|
||||
"""
|
||||
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
|
||||
step_index = step_index[0]
|
||||
|
||||
sigma = state.sigmas[step_index]
|
||||
sample = sample / ((sigma**2 + 1) ** 0.5)
|
||||
return sample
|
||||
|
||||
def set_timesteps(
|
||||
self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
|
||||
) -> EulerDiscreteSchedulerState:
|
||||
"""
|
||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Args:
|
||||
state (`EulerDiscreteSchedulerState`):
|
||||
the `FlaxEulerDiscreteScheduler` state data class instance.
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
"""
|
||||
|
||||
if self.config.timestep_spacing == "linspace":
|
||||
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
|
||||
elif self.config.timestep_spacing == "leading":
|
||||
step_ratio = self.config.num_train_timesteps // num_inference_steps
|
||||
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
|
||||
timesteps += 1
|
||||
else:
|
||||
raise ValueError(
|
||||
f"timestep_spacing must be one of ['linspace', 'leading'], got {self.config.timestep_spacing}"
|
||||
)
|
||||
|
||||
sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
|
||||
sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
|
||||
sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
||||
init_noise_sigma = sigmas.max()
|
||||
else:
|
||||
init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5
|
||||
|
||||
return state.replace(
|
||||
timesteps=timesteps,
|
||||
sigmas=sigmas,
|
||||
num_inference_steps=num_inference_steps,
|
||||
init_noise_sigma=init_noise_sigma,
|
||||
)
|
||||
|
||||
def step(
|
||||
self,
|
||||
state: EulerDiscreteSchedulerState,
|
||||
model_output: jnp.ndarray,
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxEulerDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
state (`EulerDiscreteSchedulerState`):
|
||||
the `FlaxEulerDiscreteScheduler` state data class instance.
|
||||
model_output (`jnp.ndarray`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
order: coefficient for multi-step inference.
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxEulerDiscreteScheduler class
|
||||
|
||||
Returns:
|
||||
[`FlaxEulerDiscreteScheduler`] or `tuple`: [`FlaxEulerDiscreteScheduler`] if `return_dict` is True,
|
||||
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if state.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
|
||||
step_index = step_index[0]
|
||||
|
||||
sigma = state.sigmas[step_index]
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = sample - sigma * model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
# * c_out + input * c_skip
|
||||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
)
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma
|
||||
|
||||
# dt = sigma_down - sigma
|
||||
dt = state.sigmas[step_index + 1] - sigma
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, state)
|
||||
|
||||
return FlaxEulerDiscreteSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
state: EulerDiscreteSchedulerState,
|
||||
original_samples: jnp.ndarray,
|
||||
noise: jnp.ndarray,
|
||||
timesteps: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
sigma = state.sigmas[timesteps].flatten()
|
||||
sigma = broadcast_to_shape_from_left(sigma, noise.shape)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -89,9 +89,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
@@ -116,7 +113,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end: float = 0.012,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
@@ -247,15 +243,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
|
||||
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
|
||||
self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
@@ -279,13 +269,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
|
||||
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
sigmas_interpol = sigmas_interpol.cpu()
|
||||
log_sigmas = self.log_sigmas.cpu()
|
||||
timesteps_interpol = np.array(
|
||||
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
|
||||
)
|
||||
|
||||
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
|
||||
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
|
||||
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
|
||||
|
||||
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
|
||||
@@ -298,44 +282,29 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
def sigma_to_t(self, sigma):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
log_sigma = sigma.log()
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
dists = log_sigma - self.log_sigmas[:, None]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
low = self.log_sigmas[low_idx]
|
||||
high = self.log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
w = w.clamp(0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
t = t.view(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
|
||||
@@ -88,9 +88,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
@@ -115,7 +112,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end: float = 0.012,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
@@ -247,14 +243,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
|
||||
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
|
||||
self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
@@ -269,12 +260,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
# interpolate timesteps
|
||||
sigmas_interpol = sigmas_interpol.cpu()
|
||||
log_sigmas = self.log_sigmas.cpu()
|
||||
timesteps_interpol = np.array(
|
||||
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
|
||||
)
|
||||
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
|
||||
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
|
||||
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
|
||||
|
||||
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
|
||||
@@ -287,6 +273,29 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = None
|
||||
|
||||
def sigma_to_t(self, sigma):
|
||||
# get log sigma
|
||||
log_sigma = sigma.log()
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - self.log_sigmas[:, None]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = self.log_sigmas[low_idx]
|
||||
high = self.log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.view(sigma.shape)
|
||||
return t
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
@@ -309,44 +318,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
|
||||
@@ -22,16 +22,10 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -44,30 +38,19 @@ def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -198,14 +181,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.disable_corrector = disable_corrector
|
||||
self.solver_p = solver_p
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -245,17 +220,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
||||
|
||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
||||
# duplicates in timesteps.
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
timesteps = timesteps[np.sort(unique_indices)]
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
|
||||
@@ -267,9 +243,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.solver_p:
|
||||
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -329,13 +302,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -351,11 +317,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Convert the model output to the corresponding type the UniPC algorithm needs.
|
||||
@@ -372,28 +334,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
|
||||
if self.predict_x0:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -409,9 +357,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.config.prediction_type == "epsilon":
|
||||
return model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
@@ -423,10 +373,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_uni_p_bh_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
order: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
|
||||
@@ -445,26 +394,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if order is None:
|
||||
if len(args) > 2:
|
||||
order = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `order` as a required keyward argument")
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
timestep_list = self.timestep_list
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
s0 = self.timestep_list[-1]
|
||||
s0, t = self.timestep_list[-1], prev_timestep
|
||||
m0 = model_output_list[-1]
|
||||
x = sample
|
||||
|
||||
@@ -472,12 +405,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
||||
return x_t
|
||||
|
||||
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
device = sample.device
|
||||
@@ -485,10 +415,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = self.step_index - i
|
||||
si = timestep_list[-(i + 1)]
|
||||
mi = model_output_list[-(i + 1)]
|
||||
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
||||
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
||||
lambda_si = self.lambda_t[si]
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
@@ -552,11 +481,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_uni_c_bh_update(
|
||||
self,
|
||||
this_model_output: torch.FloatTensor,
|
||||
*args,
|
||||
last_sample: torch.FloatTensor = None,
|
||||
this_sample: torch.FloatTensor = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
this_timestep: int,
|
||||
last_sample: torch.FloatTensor,
|
||||
this_sample: torch.FloatTensor,
|
||||
order: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the UniC (B(h) version).
|
||||
@@ -577,42 +505,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The corrected sample tensor at the current timestep.
|
||||
"""
|
||||
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
|
||||
if last_sample is None:
|
||||
if len(args) > 1:
|
||||
last_sample = args[1]
|
||||
else:
|
||||
raise ValueError(" missing`last_sample` as a required keyward argument")
|
||||
if this_sample is None:
|
||||
if len(args) > 2:
|
||||
this_sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`this_sample` as a required keyward argument")
|
||||
if order is None:
|
||||
if len(args) > 3:
|
||||
order = args[3]
|
||||
else:
|
||||
raise ValueError(" missing`order` as a required keyward argument")
|
||||
if this_timestep is not None:
|
||||
deprecate(
|
||||
"this_timestep",
|
||||
"1.0.0",
|
||||
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
timestep_list = self.timestep_list
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
s0, t = timestep_list[-1], this_timestep
|
||||
m0 = model_output_list[-1]
|
||||
x = last_sample
|
||||
x_t = this_sample
|
||||
model_t = this_model_output
|
||||
|
||||
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
device = this_sample.device
|
||||
@@ -620,10 +524,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = self.step_index - (i + 1)
|
||||
si = timestep_list[-(i + 1)]
|
||||
mi = model_output_list[-(i + 1)]
|
||||
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
||||
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
||||
lambda_si = self.lambda_t[si]
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
@@ -686,25 +589,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -732,27 +616,37 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
|
||||
use_corrector = (
|
||||
self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
|
||||
step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None
|
||||
)
|
||||
|
||||
model_output_convert = self.convert_model_output(model_output, sample=sample)
|
||||
model_output_convert = self.convert_model_output(model_output, timestep, sample)
|
||||
if use_corrector:
|
||||
sample = self.multistep_uni_c_bh_update(
|
||||
this_model_output=model_output_convert,
|
||||
this_timestep=timestep,
|
||||
last_sample=self.last_sample,
|
||||
this_sample=sample,
|
||||
order=self.this_order,
|
||||
)
|
||||
|
||||
# now prepare to run the predictor
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.timestep_list[i] = self.timestep_list[i + 1]
|
||||
@@ -761,7 +655,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timestep_list[-1] = timestep
|
||||
|
||||
if self.config.lower_order_final:
|
||||
this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
|
||||
this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
|
||||
else:
|
||||
this_order = self.config.solver_order
|
||||
|
||||
@@ -771,6 +665,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.last_sample = sample
|
||||
prev_sample = self.multistep_uni_p_bh_update(
|
||||
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
|
||||
prev_timestep=prev_timestep,
|
||||
sample=sample,
|
||||
order=self.this_order,
|
||||
)
|
||||
@@ -778,9 +673,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -801,31 +693,28 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -37,7 +37,6 @@ class FlaxKarrasDiffusionSchedulers(Enum):
|
||||
FlaxPNDMScheduler = 3
|
||||
FlaxLMSDiscreteScheduler = 4
|
||||
FlaxDPMSolverMultistepScheduler = 5
|
||||
FlaxEulerDiscreteScheduler = 6
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -67,7 +67,6 @@ from .import_utils import (
|
||||
is_note_seq_available,
|
||||
is_omegaconf_available,
|
||||
is_onnx_available,
|
||||
is_peft_available,
|
||||
is_scipy_available,
|
||||
is_tensorboard_available,
|
||||
is_torch_available,
|
||||
@@ -83,9 +82,7 @@ from .import_utils import (
|
||||
from .loading_utils import load_image
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .peft_utils import recurse_remove_peft_layers
|
||||
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
|
||||
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -60,18 +60,3 @@ class FlaxStableDiffusionPipeline(metaclass=DummyObject):
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax", "transformers"])
|
||||
|
||||
|
||||
class FlaxStableDiffusionXLPipeline(metaclass=DummyObject):
|
||||
_backends = ["flax", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax", "transformers"])
|
||||
|
||||
@@ -122,21 +122,6 @@ class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
|
||||
class FlaxEulerDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
|
||||
class FlaxKarrasVeScheduler(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
||||
@@ -315,36 +315,6 @@ class AutoPipelineForText2Image(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlipDiffusionControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlipDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CLIPImageProjection(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -267,14 +267,6 @@ except importlib_metadata.PackageNotFoundError:
|
||||
_invisible_watermark_available = False
|
||||
|
||||
|
||||
_peft_available = importlib.util.find_spec("peft") is not None
|
||||
try:
|
||||
_peft_version = importlib_metadata.version("peft")
|
||||
logger.debug(f"Successfully imported peft version {_peft_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_peft_available = False
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
@@ -359,10 +351,6 @@ def is_invisible_watermark_available():
|
||||
return _invisible_watermark_available
|
||||
|
||||
|
||||
def is_peft_available():
|
||||
return _peft_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
PEFT utilities: Utilities related to peft library
|
||||
"""
|
||||
from .import_utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def recurse_remove_peft_layers(model):
|
||||
r"""
|
||||
Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
|
||||
"""
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
for name, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
## compound module, go inside it
|
||||
recurse_remove_peft_layers(module)
|
||||
|
||||
module_replaced = False
|
||||
|
||||
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
|
||||
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
|
||||
module.weight.device
|
||||
)
|
||||
new_module.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_module.bias = module.bias
|
||||
|
||||
module_replaced = True
|
||||
elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
|
||||
new_module = torch.nn.Conv2d(
|
||||
module.in_channels,
|
||||
module.out_channels,
|
||||
module.kernel_size,
|
||||
module.stride,
|
||||
module.padding,
|
||||
module.dilation,
|
||||
module.groups,
|
||||
module.bias,
|
||||
).to(module.weight.device)
|
||||
|
||||
new_module.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_module.bias = module.bias
|
||||
|
||||
module_replaced = True
|
||||
|
||||
if module_replaced:
|
||||
setattr(model, name, new_module)
|
||||
del module
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return model
|
||||
@@ -1,184 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
State dict utilities: utility methods for converting state dicts easily
|
||||
"""
|
||||
import enum
|
||||
|
||||
|
||||
class StateDictType(enum.Enum):
|
||||
"""
|
||||
The mode to use when converting state dicts.
|
||||
"""
|
||||
|
||||
DIFFUSERS_OLD = "diffusers_old"
|
||||
# KOHYA_SS = "kohya_ss" # TODO: implement this
|
||||
PEFT = "peft"
|
||||
DIFFUSERS = "diffusers"
|
||||
|
||||
|
||||
DIFFUSERS_TO_PEFT = {
|
||||
".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
|
||||
".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
|
||||
".k_proj.lora_linear_layer.up": ".k_proj.lora_B",
|
||||
".k_proj.lora_linear_layer.down": ".k_proj.lora_A",
|
||||
".v_proj.lora_linear_layer.up": ".v_proj.lora_B",
|
||||
".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
|
||||
".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
|
||||
".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
|
||||
}
|
||||
|
||||
DIFFUSERS_OLD_TO_PEFT = {
|
||||
".to_q_lora.up": ".q_proj.lora_B",
|
||||
".to_q_lora.down": ".q_proj.lora_A",
|
||||
".to_k_lora.up": ".k_proj.lora_B",
|
||||
".to_k_lora.down": ".k_proj.lora_A",
|
||||
".to_v_lora.up": ".v_proj.lora_B",
|
||||
".to_v_lora.down": ".v_proj.lora_A",
|
||||
".to_out_lora.up": ".out_proj.lora_B",
|
||||
".to_out_lora.down": ".out_proj.lora_A",
|
||||
}
|
||||
|
||||
PEFT_TO_DIFFUSERS = {
|
||||
".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
|
||||
".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
|
||||
".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
|
||||
".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
|
||||
".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
|
||||
".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
|
||||
".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
|
||||
".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
|
||||
}
|
||||
|
||||
DIFFUSERS_OLD_TO_DIFFUSERS = {
|
||||
".to_q_lora.up": ".q_proj.lora_linear_layer.up",
|
||||
".to_q_lora.down": ".q_proj.lora_linear_layer.down",
|
||||
".to_k_lora.up": ".k_proj.lora_linear_layer.up",
|
||||
".to_k_lora.down": ".k_proj.lora_linear_layer.down",
|
||||
".to_v_lora.up": ".v_proj.lora_linear_layer.up",
|
||||
".to_v_lora.down": ".v_proj.lora_linear_layer.down",
|
||||
".to_out_lora.up": ".out_proj.lora_linear_layer.up",
|
||||
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
|
||||
}
|
||||
|
||||
PEFT_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
|
||||
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
|
||||
}
|
||||
|
||||
DIFFUSERS_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
|
||||
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
|
||||
}
|
||||
|
||||
|
||||
def convert_state_dict(state_dict, mapping):
|
||||
r"""
|
||||
Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
mapping (`dict[str, str]`):
|
||||
The mapping to use for conversion, the mapping should be a dictionary with the following structure:
|
||||
- key: the pattern to replace
|
||||
- value: the pattern to replace with
|
||||
|
||||
Returns:
|
||||
converted_state_dict (`dict`)
|
||||
The converted state dict.
|
||||
"""
|
||||
converted_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
for pattern in mapping.keys():
|
||||
if pattern in k:
|
||||
new_pattern = mapping[pattern]
|
||||
k = k.replace(pattern, new_pattern)
|
||||
break
|
||||
converted_state_dict[k] = v
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs):
|
||||
r"""
|
||||
Converts a state dict to the PEFT format The state dict can be from previous diffusers format (`OLD_DIFFUSERS`), or
|
||||
new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for now.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
original_type (`StateDictType`, *optional*):
|
||||
The original type of the state dict, if not provided, the method will try to infer it automatically.
|
||||
"""
|
||||
if original_type is None:
|
||||
# Old diffusers to PEFT
|
||||
if any("to_out_lora" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.DIFFUSERS_OLD
|
||||
elif any("lora_linear_layer" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.DIFFUSERS
|
||||
else:
|
||||
raise ValueError("Could not automatically infer state dict type")
|
||||
|
||||
if original_type not in PEFT_STATE_DICT_MAPPINGS.keys():
|
||||
raise ValueError(f"Original type {original_type} is not supported")
|
||||
|
||||
mapping = PEFT_STATE_DICT_MAPPINGS[original_type]
|
||||
return convert_state_dict(state_dict, mapping)
|
||||
|
||||
|
||||
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
|
||||
r"""
|
||||
Converts a state dict to new diffusers format. The state dict can be from previous diffusers format
|
||||
(`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will
|
||||
return the state dict as is.
|
||||
|
||||
The method only supports the conversion from diffusers old, PEFT to diffusers new for now.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
original_type (`StateDictType`, *optional*):
|
||||
The original type of the state dict, if not provided, the method will try to infer it automatically.
|
||||
kwargs (`dict`, *args*):
|
||||
Additional arguments to pass to the method.
|
||||
|
||||
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
|
||||
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
|
||||
`get_peft_model_state_dict` method:
|
||||
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
|
||||
but we add it here in case we don't want to rely on that method.
|
||||
"""
|
||||
peft_adapter_name = kwargs.pop("adapter_name", None)
|
||||
if peft_adapter_name is not None:
|
||||
peft_adapter_name = "." + peft_adapter_name
|
||||
else:
|
||||
peft_adapter_name = ""
|
||||
|
||||
if original_type is None:
|
||||
# Old diffusers to PEFT
|
||||
if any("to_out_lora" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.DIFFUSERS_OLD
|
||||
elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.PEFT
|
||||
elif any("lora_linear_layer" in k for k in state_dict.keys()):
|
||||
# nothing to do
|
||||
return state_dict
|
||||
else:
|
||||
raise ValueError("Could not automatically infer state dict type")
|
||||
|
||||
if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
|
||||
raise ValueError(f"Original type {original_type} is not supported")
|
||||
|
||||
mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
|
||||
return convert_state_dict(state_dict, mapping)
|
||||
@@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import io
|
||||
import logging
|
||||
@@ -30,11 +29,9 @@ from .import_utils import (
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_peft_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
from .logging import get_logger
|
||||
|
||||
@@ -43,15 +40,6 @@ global_rng = random.Random()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_required_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
) > version.parse("0.5")
|
||||
_required_transformers_version = is_transformers_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("transformers")).base_version
|
||||
) > version.parse("4.33")
|
||||
|
||||
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
@@ -248,21 +236,6 @@ def require_torchsde(test_case):
|
||||
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
|
||||
|
||||
|
||||
def require_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and
|
||||
transformers.
|
||||
"""
|
||||
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
|
||||
|
||||
|
||||
def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
"""
|
||||
return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
|
||||
|
||||
|
||||
def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
|
||||
if isinstance(arry, str):
|
||||
# local_path = "/home/patrick_huggingface_co/"
|
||||
|
||||
@@ -52,15 +52,7 @@ from diffusers.models.attention_processor import (
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
deprecate_after_peft_backend,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
nightly,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor, load_image, nightly, require_torch_gpu, slow, torch_device
|
||||
|
||||
|
||||
def create_lora_layers(model, mock_weights: bool = True):
|
||||
@@ -189,7 +181,6 @@ def state_dicts_almost_equal(sd1, sd2):
|
||||
return models_are_equal
|
||||
|
||||
|
||||
@deprecate_after_peft_backend
|
||||
class LoraLoaderMixinTests(unittest.TestCase):
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
@@ -782,7 +773,6 @@ class SDXInpaintLoraMixinTests(unittest.TestCase):
|
||||
assert np.abs(image_slice - image_slice_2).max() > 1e-2
|
||||
|
||||
|
||||
@deprecate_after_peft_backend
|
||||
class SDXLLoraLoaderMixinTests(unittest.TestCase):
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
@@ -1886,25 +1876,6 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(np.allclose(images, expected, atol=1e-3))
|
||||
|
||||
def test_lycoris(self):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16"
|
||||
).to(torch_device)
|
||||
lora_model_id = "hf-internal-testing/edgLycorisMugler-light"
|
||||
lora_filename = "edgLycorisMugler-light.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
|
||||
images = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
|
||||
images = images[0, -3:, -3:, -1].flatten()
|
||||
expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017])
|
||||
|
||||
self.assertTrue(np.allclose(images, expected, atol=1e-3))
|
||||
|
||||
def test_a1111_with_model_cpu_offload(self):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user