Compare commits

..

4 Commits

Author SHA1 Message Date
Dhruv Nair abf4a9271e skip test 2023-09-19 12:39:40 +00:00
Dhruv Nair 0e1fb0d916 merge upstream 2023-09-19 11:27:08 +00:00
Dhruv Nair f77b7a0f27 fix tests 2023-09-19 04:32:19 +00:00
Dhruv Nair eae1371983 wip 2023-09-19 03:37:22 +00:00
244 changed files with 1345 additions and 8185 deletions
+1 -2
View File
@@ -15,7 +15,6 @@ body:
*The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*
- 3. Add the **minimum amount of code / context that is needed to understand, reproduce your issue**.
*Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*
- 4. For issues related to community pipelines (i.e., the pipelines located in the `examples/community` folder), please tag the author of the pipeline in your issue thread as those pipelines are not maintained.
- type: markdown
attributes:
value: |
@@ -71,7 +70,7 @@ body:
Questions on schedulers: @patrickvonplaten and @williamberman
Questions on models and pipelines: @patrickvonplaten, @sayakpaul, and @williamberman (for community pipelines, please tag the original author of the pipeline)
Questions on models and pipelines: @patrickvonplaten, @sayakpaul, and @williamberman
Questions on JAX- and MPS-related things: @pcuenca
@@ -26,7 +26,6 @@ jobs:
image-name:
- diffusers-pytorch-cpu
- diffusers-pytorch-cuda
- diffusers-pytorch-compile-cuda
- diffusers-flax-cpu
- diffusers-flax-tpu
- diffusers-onnxruntime-cpu
+1 -1
View File
@@ -20,7 +20,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.7"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
+2 -2
View File
@@ -20,7 +20,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.7"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
@@ -38,7 +38,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.7"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
@@ -1,67 +0,0 @@
name: Fast tests for PRs - PEFT backend
on:
pull_request:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
jobs:
run_fast_tests:
strategy:
fail-fast: false
matrix:
config:
- name: LoRA
framework: lora
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_lora
name: ${{ matrix.config.name }}
runs-on: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate.git
python -m pip install -U git+https://github.com/huggingface/transformers.git
python -m pip install -U git+https://github.com/huggingface/peft.git
- name: Environment
run: |
python utils/print_env.py
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
if: ${{ matrix.config.framework == 'lora' }}
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_${{ matrix.config.report }} \
tests/lora/test_lora_layers_peft.py
+2 -46
View File
@@ -74,11 +74,11 @@ jobs:
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and not compile" \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
@@ -113,50 +113,6 @@ jobs:
name: ${{ matrix.config.report }}_test_reports
path: reports
run_torch_compile_tests:
name: PyTorch Compile CUDA tests
runs-on: docker-gpu
container:
image: diffusers/diffusers-pytorch-compile-cuda
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: |
nvidia-smi
- name: Install dependencies
run: |
python -m pip install -e .[quality,test,training]
- name: Environment
run: |
python utils/print_env.py
- name: Run example tests on GPU
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: torch_compile_test_reports
path: reports
run_examples_tests:
name: Examples PyTorch CUDA tests on Ubuntu
+1 -1
View File
@@ -17,7 +17,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: 3.7
- name: Install requirements
run: |
@@ -1,47 +0,0 @@
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && \
apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
libgl1 \
python3.9 \
python3-pip \
python3.9-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
invisible_watermark && \
python3 -m pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
Jinja2 \
librosa \
numpy \
scipy \
tensorboard \
transformers \
omegaconf \
pytorch-lightning \
xformers
CMD ["/bin/bash"]
-2
View File
@@ -216,8 +216,6 @@
title: AudioLDM 2
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP Diffusion
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
+13 -13
View File
@@ -67,30 +67,30 @@ By default, `tqdm` progress bars are displayed during model download. [`logging.
## Base setters
[[autodoc]] utils.logging.set_verbosity_error
[[autodoc]] logging.set_verbosity_error
[[autodoc]] utils.logging.set_verbosity_warning
[[autodoc]] logging.set_verbosity_warning
[[autodoc]] utils.logging.set_verbosity_info
[[autodoc]] logging.set_verbosity_info
[[autodoc]] utils.logging.set_verbosity_debug
[[autodoc]] logging.set_verbosity_debug
## Other functions
[[autodoc]] utils.logging.get_verbosity
[[autodoc]] logging.get_verbosity
[[autodoc]] utils.logging.set_verbosity
[[autodoc]] logging.set_verbosity
[[autodoc]] utils.logging.get_logger
[[autodoc]] logging.get_logger
[[autodoc]] utils.logging.enable_default_handler
[[autodoc]] logging.enable_default_handler
[[autodoc]] utils.logging.disable_default_handler
[[autodoc]] logging.disable_default_handler
[[autodoc]] utils.logging.enable_explicit_format
[[autodoc]] logging.enable_explicit_format
[[autodoc]] utils.logging.reset_format
[[autodoc]] logging.reset_format
[[autodoc]] utils.logging.enable_progress_bar
[[autodoc]] logging.enable_progress_bar
[[autodoc]] utils.logging.disable_progress_bar
[[autodoc]] logging.disable_progress_bar
@@ -1,29 +0,0 @@
# Blip Diffusion
Blip Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
The abstract from the paper is:
*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.*
The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization.
`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).
<Tip>
Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## BlipDiffusionPipeline
[[autodoc]] BlipDiffusionPipeline
- all
- __call__
## BlipDiffusionControlNetPipeline
[[autodoc]] BlipDiffusionControlNetPipeline
- all
- __call__
+6
View File
@@ -34,7 +34,13 @@ Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to le
- load_lora_weights
- save_lora_weights
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
## StableDiffusionXLInstructPix2PixPipeline
[[autodoc]] StableDiffusionXLInstructPix2PixPipeline
- __call__
- all
## StableDiffusionXLPipelineOutput
[[autodoc]] pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput
@@ -31,5 +31,5 @@ Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to le
- __call__
## StableDiffusionSafePipelineOutput
[[autodoc]] pipelines.semantic_stable_diffusion.pipeline_output.SemanticStableDiffusionPipelineOutput
- all
[[autodoc]] pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput
- all
+1 -14
View File
@@ -2,7 +2,7 @@
<img src="https://github.com/dome272/Wuerstchen/assets/61938694/0617c863-165a-43ee-9303-2a17299a0cf9">
[Würstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, Mats L. Richter and Christopher Pal and Marc Aubreville.
[Würstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville.
The abstract from the paper is:
@@ -134,16 +134,3 @@ The original codebase, as well as experimental ideas, can be found at [dome272/W
[[autodoc]] WuerstchenDecoderPipeline
- all
- __call__
## Citation
```bibtex
@misc{pernias2023wuerstchen,
title={Wuerstchen: Efficient Pretraining of Text-to-Image Models},
author={Pablo Pernias and Dominic Rampas and Mats L. Richter and Christopher Pal and Marc Aubreville},
year={2023},
eprint={2306.00637},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
+2 -2
View File
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
Install 🤗 Diffusers for whichever deep learning library you're working with.
🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
@@ -106,7 +106,7 @@ pip install -e ".[flax]"
These commands will link the folder you cloned the repository to and your Python library paths.
Python will now look inside the folder you cloned to in addition to the normal library paths.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
<Tip warning={true}>
+62 -568
View File
@@ -10,597 +10,91 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Image-to-image
# Text-guided image-to-image generation
[[open-in-colab]]
Image-to-image is similar to [text-to-image](conditional_image_generation), but in addition to a prompt, you can also pass an initial image as a starting point for the diffusion process. The initial image is encoded to latent space and noise is added to it. Then the latent diffusion model takes a prompt and the noisy latent image, predicts the added noise, and removes the predicted noise from the initial latent image to get the new latent image. Lastly, a decoder decodes the new latent image back into an image.
The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images.
With 🤗 Diffusers, this is as easy as 1-2-3:
1. Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint:
Before you begin, make sure you have all the necessary libraries installed:
```py
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# uncomment to install the necessary libraries in Colab
#!pip install diffusers transformers ftfy accelerate
```
Get started by creating a [`StableDiffusionImg2ImgPipeline`] with a pretrained Stable Diffusion model like [`nitrosocke/Ghibli-Diffusion`](https://huggingface.co/nitrosocke/Ghibli-Diffusion).
```python
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionImg2ImgPipeline
device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"nitrosocke/Ghibli-Diffusion", torch_dtype=torch.float16, use_safetensors=True
).to(device)
```
Download and preprocess an initial image so you can pass it to the pipeline:
```python
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image.thumbnail((768, 768))
init_image
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/image_2_image_using_diffusers_cell_8_output_0.jpeg"/>
</div>
<Tip>
You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](/optimization/torch2.0#scaled-dot-product-attention).
💡 `strength` is a value between 0.0 and 1.0 that controls the amount of noise added to the input image. Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
</Tip>
2. Load an image to pass to the pipeline:
Define the prompt (for this checkpoint finetuned on Ghibli-style art, you need to prefix the prompt with the `ghibli style` tokens) and run the pipeline:
```py
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
```
3. Pass a prompt and image to the pipeline to generate an image:
```py
prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
image = pipeline(prompt, image=init_image).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
## Popular models
The most popular image-to-image models are [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). The results from the Stable Diffusion and Kandinsky models vary due to their architecture differences and training process; you can generally expect SDXL to produce higher quality images than Stable Diffusion v1.5. Let's take a quick look at how to use each of these models and compare their results.
### Stable Diffusion v1.5
Stable Diffusion v1.5 is a latent diffusion model intialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image:
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdv1.5.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
### Stable Diffusion XL (SDXL)
SDXL is a more powerful version of the Stable Diffusion model. It uses a larger base model, and an additional refiner model to increase the quality of the base model's output. Read the [SDXL](sdxl) guide for a more detailed walkthrough of how to use this model, and other techniques it uses to produce high quality images.
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, strength=).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
### Kandinsky 2.2
The Kandinsky model is different from the Stable Diffusion models because it uses an image prior model to create image embeddings. The embeddings help create a better alignment between text and images, allowing the latent diffusion model to generate better images.
The simplest way to use Kandinsky 2.2 is:
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-kandinsky.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
## Configure pipeline parameters
There are several important parameters you can configure in the pipeline that'll affect the image generation process and image quality. Let's take a closer look at what these parameters do and how changing them affects the output.
### Strength
`strength` is one of the most important parameters to consider and it'll have a huge impact on your generated image. It determines how much the generated image resembles the initial image. In other words:
- 📈 a higher `strength` value gives the model more "creativity" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored
- 📉 a lower `strength` value means the generated image is more similar to the initial image
The `strength` and `num_inference_steps` parameter are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = init_image
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, strength=0.8).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-0.4.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 0.4</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-0.6.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 0.6</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-1.0.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 1.0</figcaption>
</div>
</div>
### Guidance scale
The `guidance_scale` parameter is used to control how closely aligned the generated image and text prompt are. A higher `guidance_scale` value means your generated image is more aligned with the prompt, while a lower `guidance_scale` value means your generated image has more space to deviate from the prompt.
You can combine `guidance_scale` with `strength` for even more precise control over how expressive the model is. For example, combine a high `strength + guidance_scale` for maximum creativity or use a combination of low `strength` and low `guidance_scale` to generate an image that resembles the initial image but is not as strictly bound to the prompt.
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-0.1.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 0.1</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-3.0.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 5.0</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-7.5.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 10.0</figcaption>
</div>
</div>
### Negative prompt
A negative prompt conditions the model to *not* include things in an image, and it can be used to improve image quality or modify an image. For example, you can improve image quality by including negative prompts like "poor details" or "blurry" to encourage the model to generate a higher quality image. Or you can modify an image by specifying things to exclude from an image.
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
# pass prompt and image to pipeline
image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-negative-1.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">negative prompt = "ugly, deformed, disfigured, poor details, bad anatomy"</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-negative-2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">negative prompt = "jungle"</figcaption>
</div>
</div>
## Chained image-to-image pipelines
There are some other interesting ways you can use an image-to-image pipeline aside from just generating an image (although that is pretty cool too). You can take it a step further and chain it with other pipelines.
### Text-to-image-to-image
Chaining a text-to-image and image-to-image pipeline allows you to generate an image from text and use the generated image as the initial image for the image-to-image pipeline. This is useful if you want to generate an image entirely from scratch. For example, let's chain a Stable Diffusion and a Kandinsky model.
Start by generating an image with the text-to-image pipeline:
```py
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
```
Now you can pass this generated image to the image-to-image pipeline:
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=image).images[0]
image
```
### Image-to-image-to-image
You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generate short GIFs, restore color to an image, or restore missing areas of an image.
Start by generating an image:
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, output_type="latent").images[0]
```
<Tip>
It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
</Tip>
Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion):
```py
pipelne = AutoPipelineForImage2Image.from_pretrained(
"ogkalu/Comic-Diffusion", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# need to include the token "charliebo artstyle" in the prompt to use this checkpoint
image = pipeline("Astronaut in a jungle, charliebo artstyle", image=image, output_type="latent").images[0]
```
Repeat one more time to generate the final image in a [pixel art style](https://huggingface.co/kohbanye/pixel-art-style):
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
"kohbanye/pixel-art-style", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# need to include the token "pixelartstyle" in the prompt to use this checkpoint
image = pipeline("Astronaut in a jungle, pixelartstyle", image=image).images[0]
image
```
### Image-to-upscaler-to-super-resolution
Another way you can chain your image-to-image pipeline is with an upscaler and super-resolution pipeline to really increase the level of details in an image.
Start with an image-to-image pipeline:
```py
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image_1 = pipeline(prompt, image=init_image, output_type="latent").images[0]
```
<Tip>
It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
</Tip>
Chain it to an upscaler pipeline to increase the image resolution:
```py
upscaler = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
upscaler.enable_model_cpu_offload()
upscaler.enable_xformers_memory_efficient_attention()
image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
```
Finally, chain it to a super-resolution pipeline to further enhance the resolution:
```py
super_res = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
super_res.enable_model_cpu_offload()
super_res.enable_xformers_memory_efficient_attention()
image_3 = upscaler(prompt, image=image_2).images[0]
image_3
```
## Control image generation
Trying to generate an image that looks exactly the way you want can be difficult, which is why controlled generation techniques and models are so useful. While you can use the `negative_prompt` to partially control image generation, there are more robust methods like prompt weighting and ControlNets.
### Prompt weighting
Prompt weighting allows you to scale the representation of each concept in a prompt. For example, in a prompt like "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", you can choose to increase or decrease the embeddings of "astronaut" and "jungle". The [Compel](https://github.com/damian0815/compel) library provides a simple syntax for adjusting prompt weights and generating the embeddings. You can learn how to create the embeddings in the [Prompt weighting](weighted_prompts) guide.
[`AutoPipelineForImage2Image`] has a `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter where you can pass the embeddings which replaces the `prompt` parameter.
```py
from diffusers import AutoPipelineForImage2Image
import torch
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
image = pipeline(prompt_emebds=prompt_embeds, # generated from Compel
negative_prompt_embeds, # generated from Compel
image=init_image,
).images[0]
```
### ControlNet
ControlNets provide a more flexible and accurate way to control image generation because you can use an additional conditioning image. The conditioning image can be a canny image, depth map, image segmentation, and even scribbles! Whatever type of conditioning image you choose, the ControlNet generates an image that preserves the information in it.
For example, let's condition an image with a depth map to keep the spatial information in the image.
```py
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((958, 960)) # resize to depth image dimensions
depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png")
```
Load a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]:
```py
from diffusers import ControlNetModel, AutoPipelineForImage2Image
from diffusers.utils import load_image
import torch
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
```
Now generate a new image conditioned on the depth map, initial image, and prompt:
```py
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipeline(prompt, image=init_image, control_image=depth_image).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">depth image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-controlnet.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">ControlNet image</figcaption>
</div>
</div>
Let's apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion) to the image generated from the ControlNet by chaining it with an image-to-image pipeline:
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
"nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
prompt = "elden ring style astronaut in a jungle" # include the token "elden ring style" in the prompt
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image, strength=0.45, guidance_scale=10.5).images[0]
```python
prompt = "ghibli style, a fantasy landscape with castles"
generator = torch.Generator(device=device).manual_seed(1024)
image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
image
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-elden-ring.png">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ghibli-castles.png"/>
</div>
## Optimize
You can also try experimenting with a different scheduler to see how that affects the output:
Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](optimization/torch2.0#scaled-dot-product-attention) or [xFormers](optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU.
```python
from diffusers import LMSDiscreteScheduler
```diff
+ pipeline.enable_model_cpu_offload()
+ pipeline.enable_xformers_memory_efficient_attention()
lms = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = lms
generator = torch.Generator(device=device).manual_seed(1024)
image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
image
```
With [`torch.compile`](optimization/torch2.0#torch.compile), you can boost your inference speed even more by wrapping your UNet with it:
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lms-ghibli.png"/>
</div>
```py
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
Check out the Spaces below, and try generating images with different values for `strength`. You'll notice that using lower values for `strength` produces images that are more similar to the original image.
To learn more, take a look at the [Reduce memory usage](optimization/memory) and [Torch 2.0](optimization/torch2.0) guides.
Feel free to also switch the scheduler to the [`LMSDiscreteScheduler`] and see how that affects the output.
<iframe
src="https://stevhliu-ghibli-img2img.hf.space"
frameborder="0"
width="850"
height="500"
></iframe>
@@ -0,0 +1,19 @@
# What is safetensors ?
[safetensors](https://github.com/huggingface/safetensors) is a different format
from the classic `.bin` which uses Pytorch which uses pickle.
Pickle is notoriously unsafe which allow any malicious file to execute arbitrary code.
The hub itself tries to prevent issues from it, but it's not a silver bullet.
`safetensors` first and foremost goal is to make loading machine learning models *safe*
in the sense that no takeover of your computer can be done.
# Why use safetensors ?
**Safety** can be one reason, if you're attempting to use a not well known model and
you're not sure about the source of the file.
And a secondary reason, is **the speed of loading**. Safetensors can load models much faster
than regular pickle files. If you spend a lot of times switching models, this can be
a huge timesave.
+2 -2
View File
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
사용하시는 라이브러리에 맞는 🤗 Diffusers를 설치하세요.
🤗 Diffusers는 Python 3.8+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요.
🤗 Diffusers는 Python 3.7+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요.
- [PyTorch 설치 안내](https://pytorch.org/get-started/locally/)
- [Flax 설치 안내](https://flax.readthedocs.io/en/latest/)
@@ -105,7 +105,7 @@ pip install -e ".[flax]"
이러한 명령어들은 저장소를 복제한 폴더와 Python 라이브러리 경로를 연결합니다.
Python은 이제 일반 라이브러리 경로에 더하여 복제한 폴더 내부를 살펴봅니다.
예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.8/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.7/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
<Tip warning={true}>
+2 -2
View File
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
在你正在使用的任意深度学习框架中安装 🤗 Diffusers 。
🤗 Diffusers已在Python 3.8+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装:
🤗 Diffusers已在Python 3.7+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装:
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
@@ -107,7 +107,7 @@ pip install -e ".[flax]"
这些命令将连接到你克隆的版本库和你的 Python 库路径。
现在,不只是在通常的库路径,Python 还会在你克隆的文件夹内寻找包。
例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.8/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`
例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.7/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`
<Tip warning={true}>
@@ -3,7 +3,7 @@ import inspect
from typing import Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from torch.nn import functional as F
from torchvision import transforms
@@ -2,7 +2,7 @@ import inspect
from typing import List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from torch import nn
from torch.nn import functional as F
@@ -14,7 +14,7 @@
from typing import List, Optional, Tuple, Union
import PIL.Image
import PIL
import torch
from torchvision import transforms
@@ -7,7 +7,7 @@ import warnings
from typing import List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
import torch.nn.functional as F
from accelerate import Accelerator
+1 -1
View File
@@ -2,7 +2,7 @@ import inspect
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+1 -1
View File
@@ -3,7 +3,7 @@ import re
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
@@ -3,7 +3,7 @@ import re
from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTokenizer
@@ -1029,7 +1029,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
Guidance rescale factor should fix overexposure when using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
@@ -1039,7 +1039,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
For most cases, `target_size` should be set to the desired height and width of the generated image. If
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
Examples:
@@ -1,7 +1,7 @@
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from diffusers import StableDiffusionImg2ImgPipeline
+1 -1
View File
@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import kornia
import numpy as np
import PIL.Image
import PIL
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
@@ -16,7 +16,7 @@ import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
@@ -24,7 +24,7 @@ from typing import List, Optional, Union
import numpy as np
import onnx
import onnx_graphsurgeon as gs
import PIL.Image
import PIL
import tensorrt as trt
import torch
from huggingface_hub import snapshot_download
@@ -24,7 +24,7 @@ from typing import List, Optional, Union
import numpy as np
import onnx
import onnx_graphsurgeon as gs
import PIL.Image
import PIL
import tensorrt as trt
import torch
from huggingface_hub import snapshot_download
+1 -1
View File
@@ -1,6 +1,6 @@
from typing import Callable, List, Optional, Union
import PIL.Image
import PIL
import torch
from transformers import (
CLIPImageProcessor,
+1 -1
View File
@@ -16,7 +16,7 @@ import math
from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
@@ -1,7 +1,7 @@
import inspect
from typing import List, Optional, Union
import PIL.Image
import PIL
import torch
from torch.nn import functional as F
from transformers import (
+1 -11
View File
@@ -907,17 +907,7 @@ def main():
if args.snr_gamma is not None:
snr = jnp.array(compute_snr(timesteps))
base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
if noise_scheduler.config.prediction_type == "v_prediction":
snr_loss_weights = base_weights + 1
else:
# Epsilon and sample prediction use the base weights.
snr_loss_weights = base_weights
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
snr_loss_weights[snr == 0] = 1.0
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
loss = loss * snr_loss_weights
loss = loss.mean()
+6 -54
View File
@@ -224,30 +224,6 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
raise ValueError(f"{model_class} is not supported.")
def compute_snr(timesteps, noise_scheduler):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
# Compute SNR
snr = (alpha / sigma) ** 2
return snr
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -548,13 +524,6 @@ def parse_args(input_args=None):
" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
),
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--pre_compute_text_embeddings",
action="store_true",
@@ -1292,34 +1261,17 @@ def main(args):
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Compute instance loss
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps, noise_scheduler)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
if args.with_prior_preservation:
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -801,22 +801,9 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
mse_loss_weights[snr == 0] = 1.0
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -654,22 +654,9 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
mse_loss_weights[snr == 0] = 1.0
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -685,22 +685,9 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
mse_loss_weights[snr == 0] = 1.0
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -833,22 +833,9 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
mse_loss_weights[snr == 0] = 1.0
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -872,21 +872,9 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample prediction use the base weights.
mse_loss_weights = base_weight
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
mse_loss_weights[snr == 0] = 1.0
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
+1 -14
View File
@@ -952,22 +952,9 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
mse_loss_weights[snr == 0] = 1.0
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -783,22 +783,9 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
mse_loss_weights[snr == 0] = 1.0
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -1072,22 +1072,9 @@ def main(args):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
mse_loss_weights[snr == 0] = 1.0
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -325,55 +325,6 @@ def parse_args(input_args=None):
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--timestep_bias_strategy",
type=str,
default="none",
choices=["earlier", "later", "range", "none"],
help=(
"The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
" Choices: ['earlier', 'later', 'range', 'none']."
" The default is 'none', which means no bias is applied, and training proceeds normally."
" The value of 'later' will increase the frequency of the model's final training timesteps."
),
)
parser.add_argument(
"--timestep_bias_multiplier",
type=float,
default=1.0,
help=(
"The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
" A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
),
)
parser.add_argument(
"--timestep_bias_begin",
type=int,
default=0,
help=(
"When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
" Defaults to zero, which equates to having no specific bias."
),
)
parser.add_argument(
"--timestep_bias_end",
type=int,
default=1000,
help=(
"When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
" Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
),
)
parser.add_argument(
"--timestep_bias_portion",
type=float,
default=0.25,
help=(
"The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased."
" A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
" whether the biased portions are in the earlier or later timesteps."
),
)
parser.add_argument(
"--snr_gamma",
type=float,
@@ -381,6 +332,15 @@ def parse_args(input_args=None):
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--force_snr_gamma",
action="store_true",
help=(
"When using SNR gamma with rescaled betas for zero terminal SNR, a divide-by-zero error can cause NaN"
" condition when computing the SNR with a sigma value of zero. This parameter overrides the check,"
" allowing the use of SNR gamma with a terminal SNR model. Use with caution, and closely monitor results."
),
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument(
"--allow_tf32",
@@ -528,47 +488,6 @@ def compute_vae_encodings(batch, vae):
return {"model_input": model_input.cpu()}
def generate_timestep_weights(args, num_timesteps):
weights = torch.ones(num_timesteps)
# Determine the indices to bias
num_to_bias = int(args.timestep_bias_portion * num_timesteps)
if args.timestep_bias_strategy == "later":
bias_indices = slice(-num_to_bias, None)
elif args.timestep_bias_strategy == "earlier":
bias_indices = slice(0, num_to_bias)
elif args.timestep_bias_strategy == "range":
# Out of the possible 1000 timesteps, we might want to focus on eg. 200-500.
range_begin = args.timestep_bias_begin
range_end = args.timestep_bias_end
if range_begin < 0:
raise ValueError(
"When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
)
if range_end > num_timesteps:
raise ValueError(
"When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
)
bias_indices = slice(range_begin, range_end)
else: # 'none' or any other string
return weights
if args.timestep_bias_multiplier <= 0:
return ValueError(
"The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
" If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
" A timestep bias multiplier less than or equal to 0 is not allowed."
)
# Apply the bias
weights[bias_indices] *= args.timestep_bias_multiplier
# Normalize
weights /= weights.sum()
return weights
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -635,6 +554,18 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# Check for terminal SNR in combination with SNR Gamma
if (
args.snr_gamma
and not args.force_snr_gamma
and (
hasattr(noise_scheduler.config, "rescale_betas_zero_snr") and noise_scheduler.config.rescale_betas_zero_snr
)
):
raise ValueError(
f"The selected noise scheduler for the model {args.pretrained_model_name_or_path} uses rescaled betas for zero SNR.\n"
"When this configuration is present, the parameter --snr_gamma may not be used without parameter --force_snr_gamma.\n"
"This is due to a mathematical incompatibility between our current SNR gamma implementation, and a sigma value of zero."
)
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
@@ -1025,18 +956,11 @@ def main(args):
)
bsz = model_input.shape[0]
if args.timestep_bias_strategy == "none":
# Sample a random timestep for each image without bias.
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
else:
# Sample a random timestep for each image, potentially biased by the timestep weights.
# Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
model_input.device
)
timesteps = torch.multinomial(weights, bsz, replacement=True).long()
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -1074,11 +998,6 @@ def main(args):
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
elif noise_scheduler.config.prediction_type == "sample":
# We set the target to latents here, but the model_pred will return the noise sample prediction.
target = model_input
# We will have to subtract the noise residual from the prediction to get the target sample.
model_pred = model_pred - noise
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
@@ -1089,22 +1008,9 @@ def main(args):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
mse_loss_weights[snr == 0] = 1.0
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
@@ -1,343 +0,0 @@
"""
This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
"""
import argparse
import os
import tempfile
import torch
from lavis.models import load_model_and_preprocess
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from diffusers import (
AutoencoderKL,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
BLIP2_CONFIG = {
"vision_config": {
"hidden_size": 1024,
"num_hidden_layers": 23,
"num_attention_heads": 16,
"image_size": 224,
"patch_size": 14,
"intermediate_size": 4096,
"hidden_act": "quick_gelu",
},
"qformer_config": {
"cross_attention_frequency": 1,
"encoder_hidden_size": 1024,
"vocab_size": 30523,
},
"num_query_tokens": 16,
}
blip2config = Blip2Config(**BLIP2_CONFIG)
def qformer_model_from_original_config():
qformer = Blip2QFormerModel(blip2config)
return qformer
def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
embeddings = {}
embeddings.update(
{
f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
f"{original_embeddings_prefix}.word_embeddings.weight"
]
}
)
embeddings.update(
{
f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
f"{original_embeddings_prefix}.position_embeddings.weight"
]
}
)
embeddings.update(
{f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
)
embeddings.update(
{f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
)
return embeddings
def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
proj_layer = {}
proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
return proj_layer
def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
attention = {}
attention.update(
{
f"{diffuser_attention_prefix}.attention.query.weight": model[
f"{original_attention_prefix}.self.query.weight"
]
}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
)
attention.update(
{
f"{diffuser_attention_prefix}.attention.value.weight": model[
f"{original_attention_prefix}.self.value.weight"
]
}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
)
attention.update(
{f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
)
attention.update(
{f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
)
attention.update(
{
f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
f"{original_attention_prefix}.output.LayerNorm.weight"
]
}
)
attention.update(
{
f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
f"{original_attention_prefix}.output.LayerNorm.bias"
]
}
)
return attention
def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
output_layers = {}
output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
output_layers.update(
{f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
)
output_layers.update(
{f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
)
return output_layers
def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
encoder = {}
for i in range(blip2config.qformer_config.num_hidden_layers):
encoder.update(
attention_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
)
)
encoder.update(
attention_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
)
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
]
}
)
encoder.update(
output_layers_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
)
)
encoder.update(
output_layers_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
)
)
return encoder
def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
visual_encoder_layer = {}
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
)
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
return visual_encoder_layer
def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
visual_encoder = {}
visual_encoder.update(
{
f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
.unsqueeze(0)
.unsqueeze(0)
}
)
visual_encoder.update(
{
f"{diffuser_prefix}.embeddings.position_embedding": model[
f"{original_prefix}.positional_embedding"
].unsqueeze(0)
}
)
visual_encoder.update(
{f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
)
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
for i in range(blip2config.vision_config.num_hidden_layers):
visual_encoder.update(
visual_encoder_layer_from_original_checkpoint(
model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
)
)
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
return visual_encoder
def qformer_original_checkpoint_to_diffusers_checkpoint(model):
qformer_checkpoint = {}
qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
qformer_checkpoint.update(
encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
)
qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
return qformer_checkpoint
def get_qformer(model):
print("loading qformer")
qformer = qformer_model_from_original_config()
qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
print("done loading qformer")
return qformer
def load_checkpoint_to_model(checkpoint, model):
with tempfile.NamedTemporaryFile(delete=False) as file:
torch.save(checkpoint, file.name)
del checkpoint
model.load_state_dict(torch.load(file.name), strict=False)
os.remove(file.name)
def save_blip_diffusion_model(model, args):
qformer = get_qformer(model)
qformer.eval()
text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
vae.eval()
text_encoder.eval()
scheduler = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
skip_prk_steps=True,
)
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
image_processor = BlipImageProcessor()
blip_diffusion = BlipDiffusionPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
image_processor=image_processor,
)
blip_diffusion.save_pretrained(args.checkpoint_path)
def main(args):
model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
save_blip_diffusion_model(model.state_dict(), args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
main(args)
@@ -35,12 +35,6 @@ if __name__ == "__main__":
type=str,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument(
"--config_files",
default=None,
type=str,
help="The YAML config file corresponding to the architecture.",
)
parser.add_argument(
"--num_in_channels",
default=None,
+2 -1
View File
@@ -256,7 +256,7 @@ setup(
package_dir={"": "src"},
packages=find_packages("src"),
include_package_data=True,
python_requires=">=3.8.0",
python_requires=">=3.7.0",
install_requires=list(install_requires),
extras_require=extras,
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
@@ -268,6 +268,7 @@ setup(
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
+1 -10
View File
@@ -3,7 +3,6 @@ __version__ = "0.22.0.dev0"
from typing import TYPE_CHECKING
from .utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
@@ -198,8 +197,6 @@ else:
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
"AudioLDMPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CycleDiffusionPipeline",
"IFImg2ImgPipeline",
@@ -369,7 +366,6 @@ else:
"FlaxDDIMScheduler",
"FlaxDDPMScheduler",
"FlaxDPMSolverMultistepScheduler",
"FlaxEulerDiscreteScheduler",
"FlaxKarrasVeScheduler",
"FlaxLMSDiscreteScheduler",
"FlaxPNDMScheduler",
@@ -397,7 +393,6 @@ else:
"FlaxStableDiffusionImg2ImgPipeline",
"FlaxStableDiffusionInpaintPipeline",
"FlaxStableDiffusionPipeline",
"FlaxStableDiffusionXLPipeline",
]
)
@@ -415,7 +410,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipelines"].extend(["MidiProcessor"])
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
from .configuration_utils import ConfigMixin
try:
@@ -463,8 +458,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
BlipDiffusionControlNetPipeline,
BlipDiffusionPipeline,
CLIPImageProjection,
ConsistencyModelPipeline,
DanceDiffusionPipeline,
@@ -676,7 +669,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxEulerDiscreteScheduler,
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
@@ -695,7 +687,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
FlaxStableDiffusionXLPipeline,
)
try:
+2 -2
View File
@@ -16,7 +16,7 @@ import warnings
from typing import List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from PIL import Image
@@ -48,7 +48,7 @@ class VaeImageProcessor(ConfigMixin):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `False`):
do_binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the image to 0/1.
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
Whether to convert the images to RGB format.
+118 -205
View File
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import re
from collections import defaultdict
@@ -24,7 +23,6 @@ import requests
import safetensors
import torch
from huggingface_hub import hf_hub_download, model_info
from packaging import version
from torch import nn
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
@@ -32,15 +30,11 @@ from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
is_accelerate_available,
is_omegaconf_available,
is_peft_available,
is_transformers_available,
logging,
recurse_remove_peft_layers,
)
from .utils.import_utils import BACKENDS_MAPPING
@@ -67,21 +61,6 @@ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
# available.
# For PEFT it is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")
_required_transformers_version = version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33")
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."
class PatchedLoraProjection(nn.Module):
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
super().__init__()
@@ -1098,7 +1077,6 @@ class LoraLoaderMixin:
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
num_fused_loras = 0
use_peft_backend = USE_PEFT_BACKEND
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
"""
@@ -1290,7 +1268,6 @@ class LoraLoaderMixin:
state_dict = pretrained_model_name_or_path_or_dict
network_alphas = None
# TODO: replace it with a method from `state_dict_utils`
if all(
(
k.startswith("lora_te_")
@@ -1543,35 +1520,55 @@ class LoraLoaderMixin:
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
if cls.use_peft_backend:
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
# Convert from the old naming convention to the new naming convention.
#
# Previously, the old LoRA layers were stored on the state dict at the
# same level as the attention block i.e.
# `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
#
# This is no actual module at that point, they were monkey patched on to the
# existing module. We want to be able to load them via their actual state dict.
# They're in `PatchedLoraProjection.lora_linear_layer` now.
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_B.weight"
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
else:
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
if network_alphas is not None:
alpha_keys = [
@@ -1581,79 +1578,56 @@ class LoraLoaderMixin:
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
if cls.use_peft_backend:
from peft import LoraConfig
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
)
lora_rank = list(rank.values())[0]
# By definition, the scale should be alpha divided by rank.
# https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
alpha = lora_scale * lora_rank
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
if patch_mlp:
target_modules += ["fc1", "fc2"]
# TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
)
else:
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
unexpected_keys = load_state_dict_results.unexpected_keys
if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
)
else:
load_state_dict_results = text_encoder.load_state_dict(
text_encoder_lora_state_dict, strict=False
)
unexpected_keys = load_state_dict_results.unexpected_keys
if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
@@ -1671,27 +1645,10 @@ class LoraLoaderMixin:
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
def _remove_text_encoder_monkey_patch(self):
if self.use_peft_backend:
remove_method = recurse_remove_peft_layers
else:
remove_method = self._remove_text_encoder_monkey_patch_classmethod
if hasattr(self, "text_encoder"):
remove_method(self.text_encoder)
if self.use_peft_backend:
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
if hasattr(self, "text_encoder_2"):
remove_method(self.text_encoder_2)
if self.use_peft_backend:
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
@classmethod
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_linear_layer = None
@@ -1718,7 +1675,6 @@ class LoraLoaderMixin:
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
@@ -1922,7 +1878,7 @@ class LoraLoaderMixin:
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
# SDXL specificity.
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
if "emb" in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
@@ -1934,13 +1890,6 @@ class LoraLoaderMixin:
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
# LyCORIS specificity.
if "time.emb.proj" in diffusers_name:
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
if "conv.shortcut" in diffusers_name:
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
# General coverage.
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
@@ -2093,38 +2042,24 @@ class LoraLoaderMixin:
if fuse_unet:
self.unet.fuse_lora(lora_scale)
if self.use_peft_backend:
from peft.tuners.tuners_utils import BaseTunerLayer
def fuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
module.scale_layer(lora_scale)
module.merge()
else:
deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
if fuse_text_encoder:
if hasattr(self, "text_encoder"):
fuse_text_encoder_lora(self.text_encoder, lora_scale)
fuse_text_encoder_lora(self.text_encoder)
if hasattr(self, "text_encoder_2"):
fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
fuse_text_encoder_lora(self.text_encoder_2)
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
r"""
@@ -2146,29 +2081,18 @@ class LoraLoaderMixin:
if unfuse_unet:
self.unet.unfuse_lora()
if self.use_peft_backend:
from peft.tuners.tuner_utils import BaseTunerLayer
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._unfuse_lora()
attn_module.k_proj._unfuse_lora()
attn_module.v_proj._unfuse_lora()
attn_module.out_proj._unfuse_lora()
def unfuse_text_encoder_lora(text_encoder):
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
else:
deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._unfuse_lora()
attn_module.k_proj._unfuse_lora()
attn_module.v_proj._unfuse_lora()
attn_module.out_proj._unfuse_lora()
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._unfuse_lora()
mlp_module.fc2._unfuse_lora()
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._unfuse_lora()
mlp_module.fc2._unfuse_lora()
if unfuse_text_encoder:
if hasattr(self, "text_encoder"):
@@ -2879,16 +2803,5 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
)
def _remove_text_encoder_monkey_patch(self):
if self.use_peft_backend:
recurse_remove_peft_layers(self.text_encoder)
# TODO: @younesbelkada handle this in transformers side
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
recurse_remove_peft_layers(self.text_encoder_2)
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
else:
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
+2 -2
View File
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
from ..utils import _LazyModule, is_flax_available, is_torch_available
_import_structure = {}
@@ -43,7 +43,7 @@ if is_flax_available():
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
+3 -6
View File
@@ -252,10 +252,7 @@ class T2IAdapter(ModelMixin, ConfigMixin):
elif adapter_type == "light_adapter":
self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
else:
raise ValueError(
f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or "
"'full_adapter_xl' or 'light_adapter'."
)
raise ValueError(f"unknown adapter_type: {type}. Choose either 'full_adapter' or 'simple_adapter'")
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
return self.adapter(x)
@@ -334,8 +331,8 @@ class FullAdapterXL(nn.Module):
self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
self.body = nn.ModuleList(self.body)
# XL has only one downsampling AdapterBlock.
self.total_downscale_factor = downscale_factor * 2
# XL has one fewer downsampling
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 2)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.unshuffle(x)
+2 -3
View File
@@ -19,7 +19,6 @@ import torch
from torch import nn
from .activations import get_activation
from .lora import LoRACompatibleLinear
def get_timestep_embedding(
@@ -167,7 +166,7 @@ class TimestepEmbedding(nn.Module):
):
super().__init__()
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -180,7 +179,7 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
if post_act_fn is None:
self.post_act = None
+11 -18
View File
@@ -25,25 +25,18 @@ from ..utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
if use_peft_backend:
from peft.tuners.lora import LoraLayer
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale
attn_module.k_proj.lora_scale = lora_scale
attn_module.v_proj.lora_scale = lora_scale
attn_module.out_proj.lora_scale = lora_scale
for module in text_encoder.modules():
if isinstance(module, LoraLayer):
module.scaling[module.active_adapter] = lora_scale
else:
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale
attn_module.k_proj.lora_scale = lora_scale
attn_module.v_proj.lora_scale = lora_scale
attn_module.out_proj.lora_scale = lora_scale
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_scale = lora_scale
mlp_module.fc2.lora_scale = lora_scale
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_scale = lora_scale
mlp_module.fc2.lora_scale = lora_scale
class LoRALinearLayer(nn.Module):
@@ -42,25 +42,9 @@ def rename_key(key):
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
# conv norm or layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
# rename attention layers
if len(pt_tuple_key) > 1:
for rename_from, rename_to in (
("to_out_0", "proj_attn"),
("to_k", "key"),
("to_v", "value"),
("to_q", "query"),
):
if pt_tuple_key[-2] == rename_from:
weight_name = pt_tuple_key[-1]
weight_name = "kernel" if weight_name == "weight" else weight_name
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
if renamed_pt_tuple_key in random_flax_state_dict:
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
return renamed_pt_tuple_key, pt_tensor.T
if (
any("norm" in str_ for str_ in pt_tuple_key)
and (pt_tuple_key[-1] == "bias")
+17 -17
View File
@@ -303,23 +303,23 @@ class FlaxModelMixin(PushToHubMixin):
"framework": "flax",
}
# Load config if we don't provide one
if config is None:
config, unused_kwargs = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
**kwargs,
)
model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)
# Load config if we don't provide a configuration
config_path = config if config is not None else pretrained_model_name_or_path
model, model_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
# model args
dtype=dtype,
**kwargs,
)
# Load model
pretrained_path_with_subfolder = (
+3 -6
View File
@@ -52,7 +52,6 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1
def setup(self):
resnets = []
@@ -73,7 +72,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
in_channels=self.out_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=self.transformer_layers_per_block,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -193,7 +192,6 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1
def setup(self):
resnets = []
@@ -215,7 +213,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
in_channels=self.out_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=self.transformer_layers_per_block,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -333,7 +331,6 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1
def setup(self):
# there is always at least one resnet
@@ -353,7 +350,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
in_channels=self.in_channels,
n_heads=self.num_attention_heads,
d_head=self.in_channels // self.num_attention_heads,
depth=self.transformer_layers_per_block,
depth=1,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
+1 -3
View File
@@ -883,6 +883,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
@@ -945,9 +946,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
down_block_res_samples = (sample,)
print("emb", emb.abs().sum())
print("sample", sample.abs().sum())
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlock2D
+2 -80
View File
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
@@ -116,11 +116,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0
use_memory_efficient_attention: bool = False
transformer_layers_per_block: Union[int, Tuple[int]] = 1
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
addition_embed_type_num_heads: int = 64
projection_class_embeddings_input_dim: Optional[int] = None
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
@@ -132,27 +127,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
added_cond_kwargs = None
if self.addition_embed_type == "text_time":
# we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
# or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
is_refiner = (
5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
== self.config.projection_class_embeddings_input_dim
)
num_micro_conditions = 5 if is_refiner else 6
text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
num_micro_conditions * self.config.addition_time_embed_dim
)
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
added_cond_kwargs = {
"text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
"time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
}
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
def setup(self):
block_out_channels = self.block_out_channels
@@ -193,24 +168,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
# transformer layers per block
transformer_layers_per_block = self.transformer_layers_per_block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
# addition embed types
if self.addition_embed_type is None:
self.add_embedding = None
elif self.addition_embed_type == "text_time":
if self.addition_time_embed_dim is None:
raise ValueError(
f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
)
self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
else:
raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")
# down
down_blocks = []
output_channel = block_out_channels[0]
@@ -225,7 +182,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
num_attention_heads=num_attention_heads[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
@@ -251,7 +207,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
in_channels=block_out_channels[-1],
dropout=self.dropout,
num_attention_heads=num_attention_heads[-1],
transformer_layers_per_block=transformer_layers_per_block[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
@@ -263,7 +218,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
reversed_num_attention_heads = list(reversed(num_attention_heads))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
@@ -277,7 +231,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
num_attention_heads=reversed_num_attention_heads[i],
add_upsample=not is_final_block,
dropout=self.dropout,
@@ -316,7 +269,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
sample,
timesteps,
encoder_hidden_states,
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
down_block_additional_residuals=None,
mid_block_additional_residual=None,
return_dict: bool = True,
@@ -348,40 +300,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)
# additional embeddings
aug_emb = None
if self.addition_embed_type == "text_time":
if added_cond_kwargs is None:
raise ValueError(
f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if text_embeds is None:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
if time_ids is None:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
# compute time embeds
time_embeds = self.add_time_proj(jnp.ravel(time_ids)) # (1, 6) => (6,) => (6, 256)
time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
aug_emb = self.add_embedding(add_embeds)
t_emb = t_emb + aug_emb if aug_emb is not None else t_emb
# 2. pre-process
sample = jnp.transpose(sample, (0, 2, 3, 1))
sample = self.conv_in(sample)
if not isinstance(t_emb, jax._src.interpreters.partial_eval.DynamicJaxprTracer):
import torch; import numpy as np
print("t_emb", torch.from_numpy(np.asarray(t_emb)).abs().sum())
print("sample", torch.from_numpy(np.asarray(sample)).abs().sum())
# 3. down
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
+8 -24
View File
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
@@ -17,7 +16,7 @@ from ..utils import (
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {"stable_diffusion": [], "stable_diffusion_xl": [], "latent_diffusion": [], "controlnet": []}
_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}
try:
if not is_torch_available():
@@ -68,10 +67,8 @@ else:
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPipeline",
@@ -143,14 +140,12 @@ else:
]
)
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
_import_structure["stable_diffusion_xl"].extend(
[
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPipeline",
]
)
_import_structure["stable_diffusion_xl"] = [
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPipeline",
]
_import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
_import_structure["text_to_video_synthesis"] = [
"TextToVideoSDPipeline",
@@ -203,7 +198,6 @@ else:
"StableDiffusionOnnxPipeline",
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
@@ -238,11 +232,6 @@ else:
"FlaxStableDiffusionPipeline",
]
)
_import_structure["stable_diffusion_xl"].extend(
[
"FlaxStableDiffusionXLPipeline",
]
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
@@ -253,7 +242,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
@@ -292,9 +281,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
from .blip_diffusion import BlipDiffusionPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
@@ -446,9 +433,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
)
from .stable_diffusion_xl import (
FlaxStableDiffusionXLPipeline,
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
@@ -27,7 +26,7 @@ else:
_import_structure["pipeline_output"] = ["AltDiffusionPipelineOutput"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
@@ -29,8 +29,7 @@ from ...utils import deprecate, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
from .pipeline_output import AltDiffusionPipelineOutput
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -304,7 +303,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -16,7 +16,7 @@ import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from packaging import version
from transformers import CLIPImageProcessor, XLMRobertaTokenizer
@@ -31,8 +31,7 @@ from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docs
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
from .pipeline_output import AltDiffusionPipelineOutput
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -302,7 +301,7 @@ class AltDiffusionImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL.Image
import PIL
from ...utils import (
BaseOutput,
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
from ...utils import _LazyModule
_import_structure = {
@@ -8,7 +8,7 @@ _import_structure = {
"pipeline_audio_diffusion": ["AudioDiffusionPipeline"],
}
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
from .mel import Mel
from .pipeline_audio_diffusion import AudioDiffusionPipeline
+1 -2
View File
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
@@ -26,7 +25,7 @@ else:
_import_structure["pipeline_audioldm"] = ["AudioLDMPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
@@ -26,7 +25,7 @@ else:
_import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
@@ -1,20 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
from .blip_image_processing import BlipImageProcessor
from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel
from .pipeline_blip_diffusion import BlipDiffusionPipeline
@@ -1,318 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for BLIP."""
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType, is_vision_available, logging
from diffusers.utils import numpy_to_pil
if is_vision_available():
import PIL.Image
logger = logging.get_logger(__name__)
# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop
# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor
class BlipImageProcessor(BaseImageProcessor):
r"""
Constructs a BLIP image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
do_center_crop: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size, default_to_square=True)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self.do_center_crop = do_center_crop
# Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
do_center_crop: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
do_convert_rgb: bool = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Controls the size of the image after `resize`. The shortest edge of the image is resized to
`size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to normalize the image by if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize:
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_rescale:
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize:
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
if do_center_crop:
images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
return encoded_outputs
# Follows diffusers.VaeImageProcessor.postprocess
def postprocess(self, sample: torch.FloatTensor, output_type: str = "pil"):
if output_type not in ["pt", "np", "pil"]:
raise ValueError(
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
)
# Equivalent to diffusers.VaeImageProcessor.denormalize
sample = (sample / 2 + 0.5).clamp(0, 1)
if output_type == "pt":
return sample
# Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "np":
return sample
# Output_type must be 'pil'
sample = numpy_to_pil(sample)
return sample
@@ -1,642 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import BertTokenizer
from transformers.activations import QuickGELUActivation as QuickGELU
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig
from transformers.models.blip_2.modeling_blip_2 import (
Blip2Encoder,
Blip2PreTrainedModel,
Blip2QFormerAttention,
Blip2QFormerIntermediate,
Blip2QFormerOutput,
)
from transformers.pytorch_utils import apply_chunking_to_forward
from transformers.utils import (
logging,
replace_return_docstrings,
)
logger = logging.get_logger(__name__)
# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py.
# But it doesn't support getting multimodal embeddings. So, this module can be
# replaced with a future `transformers` version supports that.
class Blip2TextEmbeddings(nn.Module):
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
def forward(
self,
input_ids=None,
position_ids=None,
query_embeds=None,
past_key_values_length=0,
):
if input_ids is not None:
seq_length = input_ids.size()[1]
else:
seq_length = 0
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
if input_ids is not None:
embeddings = self.word_embeddings(input_ids)
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
if query_embeds is not None:
batch_size = embeddings.shape[0]
# repeat the query embeddings for batch size
query_embeds = query_embeds.repeat(batch_size, 1, 1)
embeddings = torch.cat((query_embeds, embeddings), dim=1)
else:
embeddings = query_embeds
embeddings = embeddings.to(query_embeds.dtype)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
class Blip2VisionEmbeddings(nn.Module):
def __init__(self, config: Blip2VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
return embeddings
# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings
class Blip2QFormerEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList(
[Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
query_length=0,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, query_length)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if layer_module.has_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
# The layers making up the Qformer encoder
class Blip2QFormerLayer(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = Blip2QFormerAttention(config)
self.layer_idx = layer_idx
if layer_idx % config.cross_attention_frequency == 0:
self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True)
self.has_cross_attention = True
else:
self.has_cross_attention = False
self.intermediate = Blip2QFormerIntermediate(config)
self.intermediate_query = Blip2QFormerIntermediate(config)
self.output_query = Blip2QFormerOutput(config)
self.output = Blip2QFormerOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
query_length=0,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
if query_length > 0:
query_attention_output = attention_output[:, :query_length, :]
if self.has_cross_attention:
if encoder_hidden_states is None:
raise ValueError("encoder_hidden_states must be given for cross-attention layers")
cross_attention_outputs = self.crossattention(
query_attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
query_attention_output = cross_attention_outputs[0]
# add cross attentions if we output attention weights
outputs = outputs + cross_attention_outputs[1:-1]
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk_query,
self.chunk_size_feed_forward,
self.seq_len_dim,
query_attention_output,
)
if attention_output.shape[1] > query_length:
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output[:, query_length:, :],
)
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
else:
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def feed_forward_chunk_query(self, attention_output):
intermediate_output = self.intermediate_query(attention_output)
layer_output = self.output_query(intermediate_output, attention_output)
return layer_output
# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder
class ProjLayer(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
super().__init__()
# Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm
self.dense1 = nn.Linear(in_dim, hidden_dim)
self.act_fn = QuickGELU()
self.dense2 = nn.Linear(hidden_dim, out_dim)
self.dropout = nn.Dropout(drop_p)
self.LayerNorm = nn.LayerNorm(out_dim, eps=eps)
def forward(self, x):
x_in = x
x = self.LayerNorm(x)
x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in
return x
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2
class Blip2VisionModel(Blip2PreTrainedModel):
main_input_name = "pixel_values"
config_class = Blip2VisionConfig
def __init__(self, config: Blip2VisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = Blip2VisionEmbeddings(config)
self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = Blip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.post_init()
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layernorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.embeddings
# Qformer model, used to get multimodal embeddings from the text and image inputs
class Blip2QFormerModel(Blip2PreTrainedModel):
"""
Querying Transformer (Q-Former), used in BLIP-2.
"""
def __init__(self, config: Blip2Config):
super().__init__(config)
self.config = config
self.embeddings = Blip2TextEmbeddings(config.qformer_config)
self.visual_encoder = Blip2VisionModel(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
if not hasattr(config, "tokenizer") or config.tokenizer is None:
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
else:
self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right")
self.tokenizer.add_special_tokens({"bos_token": "[DEC]"})
self.proj_layer = ProjLayer(
in_dim=config.qformer_config.hidden_size,
out_dim=config.qformer_config.hidden_size,
hidden_dim=config.qformer_config.hidden_size * 4,
drop_p=0.1,
eps=1e-12,
)
self.encoder = Blip2QFormerEncoder(config.qformer_config)
self.post_init()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def get_extended_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple[int],
device: torch.device,
has_query: bool = False,
) -> torch.Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
device (`torch.device`):
The device of the input to the model.
Returns:
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def forward(
self,
text_input=None,
image_input=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
`(batch_size, sequence_length)`.
use_cache (`bool`, `optional`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
text = self.tokenizer(text_input, return_tensors="pt", padding=True)
text = text.to(self.device)
input_ids = text.input_ids
batch_size = input_ids.shape[0]
query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device)
attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# past_key_values_length
past_key_values_length = (
past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
)
query_length = self.query_tokens.shape[1]
embedding_output = self.embeddings(
input_ids=input_ids,
query_embeds=self.query_tokens,
past_key_values_length=past_key_values_length,
)
# embedding_output = self.layernorm(query_embeds)
# embedding_output = self.dropout(embedding_output)
input_shape = embedding_output.size()[:-1]
batch_size, seq_length = input_shape
device = embedding_output.device
image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state
# image_embeds_frozen = torch.ones_like(image_embeds_frozen)
encoder_hidden_states = image_embeds_frozen
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, list):
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
else:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if isinstance(encoder_attention_mask, list):
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
elif encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
query_length=query_length,
)
sequence_output = encoder_outputs[0]
pooled_output = sequence_output[:, 0, :]
if not return_dict:
return self.proj_layer(sequence_output[:, :query_length, :])
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
@@ -1,212 +0,0 @@
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import CLIPPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.configuration_clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import (
CLIPEncoder,
_expand_mask,
)
# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer
# They pass through the clip model, along with the text embeddings, and interact with them using self attention
class ContextCLIPTextModel(CLIPPreTrainedModel):
config_class = CLIPTextConfig
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
self.text_model = ContextCLIPTextTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
ctx_embeddings: torch.Tensor = None,
ctx_begin_pos: list = None,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
return self.text_model(
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=ctx_begin_pos,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class ContextCLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = ContextCLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim)
def forward(
self,
ctx_embeddings: torch.Tensor,
ctx_begin_pos: list,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=ctx_begin_pos,
)
bsz, seq_len = input_shape
if ctx_embeddings is not None:
seq_len += ctx_embeddings.size(1)
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=input_ids.device),
input_ids.to(torch.int).argmax(dim=-1),
]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class ContextCLIPTextEmbeddings(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(
self,
ctx_embeddings: torch.Tensor,
ctx_begin_pos: list,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
if ctx_embeddings is None:
ctx_len = 0
else:
ctx_len = ctx_embeddings.shape[1]
seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
# for each input embeddings, add the ctx embeddings at the correct position
input_embeds_ctx = []
bsz = inputs_embeds.shape[0]
if ctx_embeddings is not None:
for i in range(bsz):
cbp = ctx_begin_pos[i]
prefix = inputs_embeds[i, :cbp]
# remove the special token embedding
suffix = inputs_embeds[i, cbp:]
input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0))
inputs_embeds = torch.stack(input_embeds_ctx, dim=0)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
@@ -1,348 +0,0 @@
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import PIL.Image
import torch
from transformers import CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .blip_image_processing import BlipImageProcessor
from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers.pipelines import BlipDiffusionPipeline
>>> from diffusers.utils import load_image
>>> import torch
>>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained(
... "Salesforce/blipdiffusion", torch_dtype=torch.float16
... ).to("cuda")
>>> cond_subject = "dog"
>>> tgt_subject = "dog"
>>> text_prompt_input = "swimming underwater"
>>> cond_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
... )
>>> guidance_scale = 7.5
>>> num_inference_steps = 25
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
>>> output = blip_diffusion_pipe(
... text_prompt_input,
... cond_image,
... cond_subject,
... tgt_subject,
... guidance_scale=guidance_scale,
... num_inference_steps=num_inference_steps,
... neg_prompt=negative_prompt,
... height=512,
... width=512,
... ).images
>>> output[0].save("image.png")
```
"""
class BlipDiffusionPipeline(DiffusionPipeline):
"""
Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer ([`CLIPTokenizer`]):
Tokenizer for the text encoder
text_encoder ([`ContextCLIPTextModel`]):
Text encoder to encode the text prompt
vae ([`AutoencoderKL`]):
VAE model to map the latents to the image
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
scheduler ([`PNDMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
qformer ([`Blip2QFormerModel`]):
QFormer model to get multi-modal embeddings from the text and image.
image_processor ([`BlipImageProcessor`]):
Image Processor to preprocess and postprocess the image.
ctx_begin_pos (int, `optional`, defaults to 2):
Position of the context token in the text encoder.
"""
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: ContextCLIPTextModel,
vae: AutoencoderKL,
unet: UNet2DConditionModel,
scheduler: PNDMScheduler,
qformer: Blip2QFormerModel,
image_processor: BlipImageProcessor,
ctx_begin_pos: int = 2,
mean: List[float] = None,
std: List[float] = None,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
image_processor=image_processor,
)
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
# from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
prompt = f"a {tgt_subject} {prompt.strip()}"
# a trick to amplify the prompt
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
return rv
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt, device=None):
device = device or self._execution_device
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
tokenized_prompt = self.tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
text_embeddings = self.text_encoder(
input_ids=tokenized_prompt.input_ids,
ctx_embeddings=query_embeds,
ctx_begin_pos=ctx_begin_pos,
)[0]
return text_embeddings
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: List[str],
reference_image: PIL.Image.Image,
source_subject_category: List[str],
target_subject_category: List[str],
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 7.5,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
neg_prompt: Optional[str] = "",
prompt_strength: float = 1.0,
prompt_reps: int = 20,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`List[str]`):
The prompt or prompts to guide the image generation.
reference_image (`PIL.Image.Image`):
The reference image to condition the generation on.
source_subject_category (`List[str]`):
The source subject category.
target_subject_category (`List[str]`):
The target subject category.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
height (`int`, *optional*, defaults to 512):
The height of the generated image.
width (`int`, *optional*, defaults to 512):
The width of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
neg_prompt (`str`, *optional*, defaults to ""):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_strength (`float`, *optional*, defaults to 1.0):
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
to amplify the prompt.
prompt_reps (`int`, *optional*, defaults to 20):
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(device)
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(source_subject_category, str):
source_subject_category = [source_subject_category]
if isinstance(target_subject_category, str):
target_subject_category = [target_subject_category]
batch_size = len(prompt)
prompt = self._build_prompt(
prompts=prompt,
tgt_subjects=target_subject_category,
prompt_strength=prompt_strength,
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt, device)
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
uncond_input = self.tokenizer(
[neg_prompt] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
latents = self.prepare_latents(
batch_size=batch_size,
num_channels=self.unet.config.in_channels,
height=height // scale_down_factor,
width=width // scale_down_factor,
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=device,
)
# set timesteps
extra_set_kwargs = {}
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
do_classifier_free_guidance = guidance_scale > 1.0
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
noise_pred = self.unet(
latent_model_input,
timestep=t,
encoder_hidden_states=text_embeddings,
down_block_additional_residuals=None,
mid_block_additional_residual=None,
)["sample"]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
t,
latents,
)["prev_sample"]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
@@ -1,14 +1,13 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
_LazyModule,
)
_import_structure = {"pipeline_consistency_models": ["ConsistencyModelPipeline"]}
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
from .pipeline_consistency_models import ConsistencyModelPipeline
else:
+77 -80
View File
@@ -1,80 +1,77 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .multicontrolnet import MultiControlNetModel
from .pipeline_controlnet import StableDiffusionControlNetPipeline
from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
if TYPE_CHECKING:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .multicontrolnet import MultiControlNetModel
from .pipeline_controlnet import StableDiffusionControlNetPipeline
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -34,7 +34,7 @@ from ...utils import (
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .multicontrolnet import MultiControlNetModel
@@ -291,7 +291,7 @@ class StableDiffusionControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1,413 +0,0 @@
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import PIL.Image
import torch
from transformers import CLIPTokenizer
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
>>> from diffusers.utils import load_image
>>> from controlnet_aux import CannyDetector
>>> import torch
>>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
... ).to("cuda")
>>> style_subject = "flower"
>>> tgt_subject = "teapot"
>>> text_prompt = "on a marble table"
>>> cldm_cond_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
... ).resize((512, 512))
>>> canny = CannyDetector()
>>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
>>> style_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
... )
>>> guidance_scale = 7.5
>>> num_inference_steps = 50
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
>>> output = blip_diffusion_pipe(
... text_prompt,
... style_image,
... cldm_cond_image,
... style_subject,
... tgt_subject,
... guidance_scale=guidance_scale,
... num_inference_steps=num_inference_steps,
... neg_prompt=negative_prompt,
... height=512,
... width=512,
... ).images
>>> output[0].save("image.png")
```
"""
class BlipDiffusionControlNetPipeline(DiffusionPipeline):
"""
Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer ([`CLIPTokenizer`]):
Tokenizer for the text encoder
text_encoder ([`ContextCLIPTextModel`]):
Text encoder to encode the text prompt
vae ([`AutoencoderKL`]):
VAE model to map the latents to the image
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
scheduler ([`PNDMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
qformer ([`Blip2QFormerModel`]):
QFormer model to get multi-modal embeddings from the text and image.
controlnet ([`ControlNetModel`]):
ControlNet model to get the conditioning image embedding.
image_processor ([`BlipImageProcessor`]):
Image Processor to preprocess and postprocess the image.
ctx_begin_pos (int, `optional`, defaults to 2):
Position of the context token in the text encoder.
"""
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: ContextCLIPTextModel,
vae: AutoencoderKL,
unet: UNet2DConditionModel,
scheduler: PNDMScheduler,
qformer: Blip2QFormerModel,
controlnet: ControlNetModel,
image_processor: BlipImageProcessor,
ctx_begin_pos: int = 2,
mean: List[float] = None,
std: List[float] = None,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
controlnet=controlnet,
image_processor=image_processor,
)
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
# from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
prompt = f"a {tgt_subject} {prompt.strip()}"
# a trick to amplify the prompt
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
return rv
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt, device=None):
device = device or self._execution_device
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
tokenized_prompt = self.tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
text_embeddings = self.text_encoder(
input_ids=tokenized_prompt.input_ids,
ctx_embeddings=query_embeds,
ctx_begin_pos=ctx_begin_pos,
)[0]
return text_embeddings
# Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
def prepare_control_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
):
image = self.image_processor.preprocess(
image,
size={"width": width, "height": height},
do_rescale=True,
do_center_crop=False,
do_normalize=False,
return_tensors="pt",
)["pixel_values"].to(self.device)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
return image
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: List[str],
reference_image: PIL.Image.Image,
condtioning_image: PIL.Image.Image,
source_subject_category: List[str],
target_subject_category: List[str],
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 7.5,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
neg_prompt: Optional[str] = "",
prompt_strength: float = 1.0,
prompt_reps: int = 20,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`List[str]`):
The prompt or prompts to guide the image generation.
reference_image (`PIL.Image.Image`):
The reference image to condition the generation on.
condtioning_image (`PIL.Image.Image`):
The conditioning canny edge image to condition the generation on.
source_subject_category (`List[str]`):
The source subject category.
target_subject_category (`List[str]`):
The target subject category.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
height (`int`, *optional*, defaults to 512):
The height of the generated image.
width (`int`, *optional*, defaults to 512):
The width of the generated image.
seed (`int`, *optional*, defaults to 42):
The seed to use for random generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
neg_prompt (`str`, *optional*, defaults to ""):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_strength (`float`, *optional*, defaults to 1.0):
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
to amplify the prompt.
prompt_reps (`int`, *optional*, defaults to 20):
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(device)
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(source_subject_category, str):
source_subject_category = [source_subject_category]
if isinstance(target_subject_category, str):
target_subject_category = [target_subject_category]
batch_size = len(prompt)
prompt = self._build_prompt(
prompts=prompt,
tgt_subjects=target_subject_category,
prompt_strength=prompt_strength,
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt, device)
# 3. unconditional embedding
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
uncond_input = self.tokenizer(
[neg_prompt] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
latents = self.prepare_latents(
batch_size=batch_size,
num_channels=self.unet.config.in_channels,
height=height // scale_down_factor,
width=width // scale_down_factor,
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=device,
)
# set timesteps
extra_set_kwargs = {}
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
cond_image = self.prepare_control_image(
image=condtioning_image,
width=width,
height=height,
batch_size=batch_size,
num_images_per_prompt=1,
device=self.device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
do_classifier_free_guidance = guidance_scale > 1.0
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=cond_image,
return_dict=False,
)
noise_pred = self.unet(
latent_model_input,
timestep=t,
encoder_hidden_states=text_embeddings,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)["sample"]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
t,
latents,
)["prev_sample"]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
@@ -315,7 +315,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -442,7 +442,7 @@ class StableDiffusionControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -16,11 +16,13 @@ import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import PIL
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
@@ -39,7 +41,6 @@ from ...utils import (
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from .multicontrolnet import MultiControlNetModel
@@ -314,8 +315,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -41,7 +41,7 @@ from ...utils import (
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
@@ -288,8 +288,8 @@ class StableDiffusionXLControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -863,7 +863,7 @@ class StableDiffusionXLControlNetPipeline(
The percentage of total steps at which the ControlNet stops applying.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
@@ -873,7 +873,7 @@ class StableDiffusionXLControlNetPipeline(
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
For most cases, `target_size` should be set to the desired height and width of the generated image. If
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
@@ -41,7 +41,7 @@ from ...utils import (
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
@@ -326,8 +326,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -1028,7 +1028,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
The percentage of total steps at which the controlnet stops applying.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
@@ -1038,7 +1038,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
For most cases, `target_size` should be set to the desired height and width of the generated image. If
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
@@ -1,11 +1,11 @@
from typing import TYPE_CHECKING
from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
from ...utils import _LazyModule
_import_structure = {"pipeline_dance_diffusion": ["DanceDiffusionPipeline"]}
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
from .pipeline_dance_diffusion import DanceDiffusionPipeline
else:
import sys
+2 -2
View File
@@ -1,11 +1,11 @@
from typing import TYPE_CHECKING
from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
from ...utils import _LazyModule
_import_structure = {"pipeline_ddim": ["DDIMPipeline"]}
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
from .pipeline_ddim import DDIMPipeline
else:
import sys
+1 -2
View File
@@ -1,14 +1,13 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
_LazyModule,
)
_import_structure = {"pipeline_ddpm": ["DDPMPipeline"]}
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
from .pipeline_ddpm import DDPMPipeline
else:
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
@@ -43,7 +42,7 @@ else:
_import_structure["watermark"] = ["IFWatermarker"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
@@ -20,7 +20,7 @@ from ...utils import (
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import IFPipelineOutput
from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker
from .watermark import IFWatermarker
@@ -5,7 +5,7 @@ import urllib.parse as ul
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
@@ -23,7 +23,7 @@ from ...utils import (
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import IFPipelineOutput
from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker
from .watermark import IFWatermarker
@@ -5,7 +5,7 @@ import urllib.parse as ul
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
@@ -24,7 +24,7 @@ from ...utils import (
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import IFPipelineOutput
from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker
from .watermark import IFWatermarker
@@ -5,7 +5,7 @@ import urllib.parse as ul
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
@@ -23,7 +23,7 @@ from ...utils import (
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import IFPipelineOutput
from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker
from .watermark import IFWatermarker
@@ -5,7 +5,7 @@ import urllib.parse as ul
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
@@ -24,7 +24,7 @@ from ...utils import (
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import IFPipelineOutput
from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker
from .watermark import IFWatermarker
@@ -5,7 +5,7 @@ import urllib.parse as ul
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
@@ -23,7 +23,7 @@ from ...utils import (
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import IFPipelineOutput
from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker
from .watermark import IFWatermarker
@@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL.Image
import PIL
from ...utils import BaseOutput
@@ -1,6 +1,6 @@
from typing import List
import PIL.Image
import PIL
import torch
from PIL import Image
+2 -2
View File
@@ -1,11 +1,11 @@
from typing import TYPE_CHECKING
from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
from ...utils import _LazyModule
_import_structure = {"pipeline_dit": ["DiTPipeline"]}
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
from .pipeline_dit import DiTPipeline
else:
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
@@ -33,7 +32,7 @@ else:
_import_structure["text_encoder"] = ["MultilingualCLIP"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if TYPE_CHECKING:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Callable, List, Optional, Union
import PIL.Image
import PIL
import torch
from transformers import (
CLIPImageProcessor,
@@ -15,7 +15,7 @@
from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
from PIL import Image
from transformers import (
@@ -16,7 +16,7 @@ from copy import deepcopy
from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
import PIL
import torch
import torch.nn.functional as F
from packaging import version

Some files were not shown because too many files have changed in this diff Show More