Compare commits

..

2 Commits

Author SHA1 Message Date
Patrick von Platen 434dab4a2f make style 2023-12-18 13:15:00 +00:00
Patrick von Platen dbcbfb3118 [SVD] Fix guidance scale 2023-11-30 16:17:53 +00:00
156 changed files with 1418 additions and 11312 deletions
+10 -5
View File
@@ -1,6 +1,12 @@
name: Fast tests for PRs - Test Fetcher
on: workflow_dispatch
on:
pull_request:
branches:
- main
push:
branches:
- ci-*
env:
DIFFUSERS_IS_CI: yes
@@ -29,15 +35,14 @@ jobs:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 0
fetch-depth: 2
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
python -m pip install -e .
- name: Environment
run: |
python utils/print_env.py
echo $(git --version)
- name: Fetch Tests
run: |
python utils/tests_fetcher.py | tee test_preparation.txt
@@ -105,7 +110,7 @@ jobs:
continue-on-error: true
run: |
cat reports/${{ matrix.modules }}_tests_cpu_stats.txt
cat reports/${{ matrix.modules }}_tests_cpu_failures_short.txt
cat reports/${{ matrix.modules }}_tests_cpu/failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
-1
View File
@@ -113,7 +113,6 @@ jobs:
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
python -m pip install peft
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples
+1 -1
View File
@@ -355,7 +355,7 @@ You will need basic `git` proficiency to be able to contribute to
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L265)):
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L244)):
1. Fork the [repository](https://github.com/huggingface/diffusers) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
+1 -1
View File
@@ -41,7 +41,7 @@ repo-consistency:
quality:
ruff check $(check_dirs) setup.py
ruff format --check $(check_dirs) setup.py
ruff format --check $(check_dirs) setup.py
python utils/check_doc_toc.py
# Format source code automatically and check is there are any problems left that need manual fixing
-4
View File
@@ -264,10 +264,6 @@
title: ControlNet
- local: api/pipelines/controlnet_sdxl
title: ControlNet with Stable Diffusion XL
- local: api/pipelines/controlnetxs
title: ControlNet-XS
- local: api/pipelines/controlnetxs_sdxl
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/cycle_diffusion
title: Cycle Diffusion
- local: api/pipelines/dance_diffusion
-3
View File
@@ -20,9 +20,6 @@ An attention processor is a class for applying different types of attention mech
## AttnProcessor2_0
[[autodoc]] models.attention_processor.AttnProcessor2_0
## FusedAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
## LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor
@@ -1,39 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet-XS
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb) with StableDiffusion-XL) and uses ~45% less memory.
Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/):
*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.*
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## StableDiffusionControlNetXSPipeline
[[autodoc]] StableDiffusionControlNetXSPipeline
- all
- __call__
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
@@ -1,45 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet-XS with Stable Diffusion XL
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb)) and uses ~45% less memory.
Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/):
*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.*
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
<Tip warning={true}>
🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
</Tip>
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## StableDiffusionXLControlNetXSPipeline
[[autodoc]] StableDiffusionXLControlNetXSPipeline
- all
- __call__
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
-3
View File
@@ -40,8 +40,6 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Consistency Models](consistency_models) | unconditional image generation |
| [ControlNet](controlnet) | text2image, image2image, inpainting |
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
| [ControlNet-XS](controlnetxs) | text2image |
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
| [Cycle Diffusion](cycle_diffusion) | image2image |
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
| [DDIM](ddim) | unconditional image generation |
@@ -73,7 +71,6 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Stable Diffusion](stable_diffusion/overview) | text2image, image2image, depth2image, inpainting, image variation, latent upscaler, super-resolution |
| [Stable Diffusion Model Editing](model_editing) | model editing |
| [Stable Diffusion XL](stable_diffusion/stable_diffusion_xl) | text2image, image2image, inpainting |
| [Stable Diffusion XL Turbo](stable_diffusion/sdxl_turbo) | text2image, image2image, inpainting |
| [Stable unCLIP](stable_unclip) | text2image, image variation |
| [Stochastic Karras VE](stochastic_karras_ve) | unconditional image generation |
| [T2I-Adapter](stable_diffusion/adapter) | text2image |
@@ -20,7 +20,7 @@ The abstract from the paper is:
## Tips
- SDXL Turbo uses the exact same architecture as [SDXL](./stable_diffusion_xl), which means it also has the same API. Please refer to the [SDXL](./stable_diffusion_xl) API reference for more details.
- SDXL Turbo uses the exact same architecture as [SDXL](./stable_diffusion_xl).
- SDXL Turbo should disable guidance scale by setting `guidance_scale=0.0`
- SDXL Turbo should use `timestep_spacing='trailing'` for the scheduler and use between 1 and 4 steps.
- SDXL Turbo has been trained to generate images of size 512x512.
@@ -28,8 +28,26 @@ The abstract from the paper is:
<Tip>
To learn how to use SDXL Turbo for various tasks, how to optimize performance, and other usage examples, take a look at the [SDXL Turbo](../../../using-diffusers/sdxl_turbo) guide.
To learn how to use SDXL Turbo for various tasks, how to optimize performance, and other usage examples, take a look at the [Stable Diffusion XL](../../../using-diffusers/sdxl_turbo) guide.
Check out the [Stability AI](https://huggingface.co/stabilityai) Hub organization for the official base and refiner model checkpoints!
</Tip>
## StableDiffusionXLPipeline
[[autodoc]] StableDiffusionXLPipeline
- all
- __call__
## StableDiffusionXLImg2ImgPipeline
[[autodoc]] StableDiffusionXLImg2ImgPipeline
- all
- __call__
## StableDiffusionXLInpaintPipeline
[[autodoc]] StableDiffusionXLInpaintPipeline
- all
- __call__
@@ -24,7 +24,7 @@ The abstract from the paper is:
*Model-based reinforcement learning methods often use learning only for the purpose of estimating an approximate dynamics model, offloading the rest of the decision-making work to classical trajectory optimizers. While conceptually simple, this combination has a number of empirical shortcomings, suggesting that learned models may not be well-suited to standard trajectory optimization. In this paper, we consider what it would look like to fold as much of the trajectory optimization pipeline as possible into the modeling problem, such that sampling from the model and planning with it become nearly identical. The core of our technical approach lies in a diffusion probabilistic model that plans by iteratively denoising trajectories. We show how classifier-guided sampling and image inpainting can be reinterpreted as coherent planning strategies, explore the unusual and useful properties of diffusion-based planning methods, and demonstrate the effectiveness of our framework in control settings that emphasize long-horizon decision-making and test-time flexibility.*
You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/drive/1rXm8CX4ZdN5qivjJ2lhwhkOmt_m0CvU0#scrollTo=6HXJvhyqcITc&uniqifier=1).
You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb).
The script to run the model is available [here](https://github.com/huggingface/diffusers/tree/main/examples/reinforcement_learning).
+7 -27
View File
@@ -297,37 +297,17 @@ if you don't know yet what specific component you would like to add:
- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that
we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the original author directly on the PR so that they can follow the progress and potentially help with questions.
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
original author directly on the PR so that they can follow the progress and potentially help with questions.
If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
#### Copied from mechanism
A unique and important feature to understand when adding any pipeline, model or scheduler code is the `# Copied from` mechanism. You'll see this all over the Diffusers codebase, and the reason we use it is to keep the codebase easy to understand and maintain. Marking code with the `# Copied from` mechanism forces the marked code to be identical to the code it was copied from. This makes it easy to update and propagate changes across many files whenever you run `make fix-copies`.
For example, in the code example below, [`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is the original code and `AltDiffusionPipelineOutput` uses the `# Copied from` mechanism to copy it. The only difference is changing the class prefix from `Stable` to `Alt`.
```py
# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt
class AltDiffusionPipelineOutput(BaseOutput):
"""
Output class for Alt Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`.
nsfw_content_detected (`List[bool]`)
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
`None` if safety checking could not be performed.
"""
```
To learn more, read this section of the [~Don't~ Repeat Yourself*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static) blog post.
## How to write a good issue
**The better your issue is written, the higher the chances that it will be quickly resolved.**
@@ -20,8 +20,6 @@ The Kandinsky models are a series of multilingual text-to-image generation model
[Kandinsky 2.2](../api/pipelines/kandinsky_v22) improves on the previous model by replacing the image encoder of the image prior model with a larger CLIP-ViT-G model to improve quality. The image prior model was also retrained on images with different resolutions and aspect ratios to generate higher-resolution images and different image sizes.
[Kandinsky 3](../api/pipelines/kandinsky3) simplifies the architecture and shifts away from the two-stage generation process involving the prior model and diffusion model. Instead, Kandinsky 3 uses [Flan-UL2](https://huggingface.co/google/flan-ul2) to encode text, a UNet with [BigGan-deep](https://hf.co/papers/1809.11096) blocks, and [Sber-MoVQGAN](https://github.com/ai-forever/MoVQGAN) to decode the latents into images. Text understanding and generated image quality are primarily achieved by using a larger text encoder and UNet.
This guide will show you how to use the Kandinsky models for text-to-image, image-to-image, inpainting, interpolation, and more.
Before you begin, make sure you have the following libraries installed:
@@ -35,10 +33,6 @@ Before you begin, make sure you have the following libraries installed:
Kandinsky 2.1 and 2.2 usage is very similar! The only difference is Kandinsky 2.2 doesn't accept `prompt` as an input when decoding the latents. Instead, Kandinsky 2.2 only accepts `image_embeds` during decoding.
<br>
Kandinsky 3 has a more concise architecture and it doesn't require a prior model. This means it's usage is identical to other diffusion models like [Stable Diffusion XL](sdxl).
</Tip>
## Text-to-image
@@ -97,23 +91,6 @@ image
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-text-to-image.png"/>
</div>
</hfoption>
<hfoption id="Kandinsky 3">
Kandinsky 3 doesn't require a prior model so you can directly load the [`Kandinsky3Pipeline`] and pass a prompt to generate an image:
```py
from diffusers import Kandinsky3Pipeline
import torch
pipeline = Kandinsky3Pipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
image = pipeline(prompt).images[0]
image
```
</hfoption>
</hfoptions>
@@ -184,20 +161,6 @@ prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kan
pipeline = KandinskyV22Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
```
</hfoption>
<hfoption id="Kandinsky 3">
Kandinsky 3 doesn't require a prior model so you can directly load the image-to-image pipeline:
```py
from diffusers import Kandinsky3Img2ImgPipeline
from diffusers.utils import load_image
import torch
pipeline = Kandinsky3Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
```
</hfoption>
</hfoptions>
@@ -255,14 +218,6 @@ make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], r
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-image-to-image.png"/>
</div>
</hfoption>
<hfoption id="Kandinsky 3">
```py
image = pipeline(prompt, negative_prompt=negative_prompt, image=image, strength=0.75, num_inference_steps=25).images[0]
image
```
</hfoption>
</hfoptions>
@@ -485,69 +485,6 @@ image.save("sdxl_t2i.png")
</div>
</div>
You can use the IP-Adapter face model to apply specific faces to your images. It is an effective way to maintain consistent characters in your image generations.
Weights are loaded with the same method used for the other IP-Adapters.
```python
# Load ip-adapter-full-face_sd15.bin
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
```
<Tip>
It is recommended to use `DDIMScheduler` and `EulerDiscreteScheduler` for face model.
</Tip>
```python
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image
noise_scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1
)
pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
scheduler=noise_scheduler,
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
pipeline.set_ip_adapter_scale(0.7)
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
ip_adapter_image=image,
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=50, num_images_per_prompt=1, width=512, height=704,
generator=generator,
).images[0]
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ipadapter_full_face_output.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
</div>
</div>
### LCM-Lora
@@ -174,4 +174,10 @@ Set `private=True` in the [`~diffusers.utils.PushToHubMixin.push_to_hub`] functi
controlnet.push_to_hub("my-controlnet-model-private", private=True)
```
Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for`. You must be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) to load a model from a private repository.
Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for.`
To load a model, scheduler, or pipeline from private or gated repositories, set `use_auth_token=True`:
```py
model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model-private", use_auth_token=True)
```
+2 -3
View File
@@ -53,9 +53,8 @@ frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
<video controls width="1024" height="576">
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.webm" type="video/webm" />
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4" type="video/mp4" />
<video width="1024" height="576" controls>
<source src="https://i.imgur.com/jJzVDKw.mp4" type="video/mp4">
</video>
<Tip>
@@ -54,7 +54,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
@@ -62,51 +62,16 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
base_model=str,
train_text_encoder=False,
train_text_encoder_ti=False,
token_abstraction_dict=None,
instance_prompt=str,
validation_prompt=str,
repo_folder=None,
@@ -118,34 +83,12 @@ def save_model_card(
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url:
url: >-
"image_{i}.png"
"""
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
diffusers_imports_pivotal = ""
diffusers_example_pivotal = ""
if train_text_encoder_ti:
trigger_str = (
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
"in you prompt with the new inserted tokens:\n"
)
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
"""
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
"""
if token_abstraction_dict:
for key, value in token_abstraction_dict.items():
tokens = "".join(value)
trigger_str += f"""
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""
yaml = f"""---
yaml = f"""
---
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
@@ -153,14 +96,14 @@ tags:
- diffusers
- lora
- template:sd-lora
widget:
{img_str}
---
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
widget:
- text: '{validation_prompt if validation_prompt else instance_prompt}'
---
"""
"""
model_card = f"""
# SDXL LoRA DreamBooth - {repo_id}
@@ -169,44 +112,20 @@ widget:
## Model description
### These are {repo_id} LoRA adaption weights for {base_model}.
These are {repo_id} LoRA adaption weights for {base_model}.
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
LoRA for the text encoder was enabled: {train_text_encoder}.
Special VAE used for training: {vae_path}.
## Trigger words
{trigger_str}
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch
{diffusers_imports_pivotal}
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
{diffusers_example_pivotal}
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
You should use {instance_prompt} to trigger the image generation.
## Download model
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
Weights for this model are available in Safetensors format.
- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it on your Lora folder.
- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it on it on your embeddings folder.
All [Files & versions](/{repo_id}/tree/main).
## Details
The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
LoRA for the text encoder was enabled. {train_text_encoder}.
Pivotal tuning was enabled: {train_text_encoder_ti}.
Special VAE used for training: {vae_path}.
[Download]({repo_id}/tree/main) them in the Files & versions tab.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
@@ -255,12 +174,6 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -268,26 +181,20 @@ def parse_args(input_args=None):
help=(
"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand.To load the custom captions, the training set directory needs to follow the structure of a "
"datasets ImageFolder, containing both the images and the corresponding caption for each image. see: "
"https://huggingface.co/docs/datasets/image_dataset for more information"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset. In some cases, a dataset may have more than one configuration (for example "
"if it contains different subsets of data within, and you only wish to load a specific subset - in that case specify the desired configuration using --dataset_config_name. Leave as "
"None if there's only one config.",
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
help="A path to local folder containing the training data of instance images. Specify this arg instead of "
"--dataset_name if you wish to train using a local folder without custom captions. If you wish to train with custom captions please specify "
"--dataset_name instead.",
help=("A folder containing the training data. "),
)
parser.add_argument(
@@ -330,18 +237,15 @@ def parse_args(input_args=None):
)
parser.add_argument(
"--token_abstraction",
type=str,
default="TOK",
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. "
"'TOK,TOK2,TOK3' etc.",
"captions - e.g. TOK",
)
parser.add_argument(
"--num_new_tokens_per_abstraction",
type=int,
default=2,
help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
help="number of new tokens inserted to the tokenizers per token_abstraction value when "
"--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
"tokens - <si><si+1> ",
)
@@ -551,7 +455,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--train_text_encoder_frac",
type=float,
default=1.0,
default=0.5,
help=("The percentage of epochs to perform text encoder tuning"),
)
@@ -584,7 +488,7 @@ def parse_args(input_args=None):
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
@@ -673,12 +577,6 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--cache_latents",
action="store_true",
default=False,
help="Cache the VAE latents",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -698,6 +596,17 @@ def parse_args(input_args=None):
"inversion training check `--train_text_encoder_ti`"
)
if args.train_text_encoder_ti:
if isinstance(args.token_abstraction, str):
args.token_abstraction = [args.token_abstraction]
elif isinstance(args.token_abstraction, List):
args.token_abstraction = args.token_abstraction
else:
raise ValueError(
f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. "
f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)"
)
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@@ -770,19 +679,12 @@ class TokenEmbeddingsHandler:
def save_embeddings(self, file_path: str):
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
tensors = {}
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
for idx, text_encoder in enumerate(self.text_encoders):
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
self.tokenizers[0]
), "Tokenizers should be the same."
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
# text_encoder 1) to keep compatible with the ecosystem.
# Note: When loading with diffusers, any name can work - simply specify in inference
tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
# tensors[f"text_encoders_{idx}"] = new_token_embeddings
tensors[f"text_encoders_{idx}"] = new_token_embeddings
save_file(tensors, file_path)
@@ -794,6 +696,19 @@ class TokenEmbeddingsHandler:
def device(self):
return self.text_encoders[0].device
# def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
# # Assuming new tokens are of the format <s_i>
# self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
# special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
# tokenizer.add_special_tokens(special_tokens_dict)
# text_encoder.resize_token_embeddings(len(tokenizer))
#
# self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
# assert self.train_ids is not None, "New tokens could not be converted to IDs."
# text_encoder.text_model.embeddings.token_embedding.weight.data[
# self.train_ids
# ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
@torch.no_grad()
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
@@ -815,6 +730,15 @@ class TokenEmbeddingsHandler:
new_embeddings = new_embeddings * (off_ratio**0.1)
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
# def load_embeddings(self, file_path: str):
# with safe_open(file_path, framework="pt", device=self.device.type) as f:
# for idx in range(len(self.text_encoders)):
# text_encoder = self.text_encoders[idx]
# tokenizer = self.tokenizers[idx]
#
# loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
# self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
class DreamBoothDataset(Dataset):
"""
@@ -827,12 +751,6 @@ class DreamBoothDataset(Dataset):
instance_data_root,
instance_prompt,
class_prompt,
dataset_name,
dataset_config_name,
cache_dir,
image_column,
caption_column,
train_text_encoder_ti,
class_data_root=None,
class_num=None,
token_abstraction_dict=None, # token mapping for textual inversion
@@ -847,10 +765,10 @@ class DreamBoothDataset(Dataset):
self.custom_instance_prompts = None
self.class_prompt = class_prompt
self.token_abstraction_dict = token_abstraction_dict
self.train_text_encoder_ti = train_text_encoder_ti
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if dataset_name is not None:
if args.dataset_name is not None:
try:
from datasets import load_dataset
except ImportError:
@@ -863,25 +781,26 @@ class DreamBoothDataset(Dataset):
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
dataset_name,
dataset_config_name,
cache_dir=cache_dir,
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
# Preprocessing the datasets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if image_column is None:
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
if caption_column is None:
if args.caption_column is None:
logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the "
@@ -889,11 +808,11 @@ class DreamBoothDataset(Dataset):
)
self.custom_instance_prompts = None
else:
if caption_column not in column_names:
if args.caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
custom_instance_prompts = dataset["train"][caption_column]
custom_instance_prompts = dataset["train"][args.caption_column]
# create final list of captions according to --repeats
self.custom_instance_prompts = []
for caption in custom_instance_prompts:
@@ -948,7 +867,7 @@ class DreamBoothDataset(Dataset):
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
if caption:
if self.train_text_encoder_ti:
if args.train_text_encoder_ti:
# replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
for token_abs, token_replacement in self.token_abstraction_dict.items():
caption = caption.replace(token_abs, "".join(token_replacement))
@@ -1102,7 +1021,6 @@ def main(args):
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -1134,25 +1052,17 @@ def main(args):
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
model_id = args.hub_model_id or Path(args.output_dir).name
repo_id = None
if args.push_to_hub:
repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
variant=args.variant,
use_fast=False,
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
variant=args.variant,
use_fast=False,
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
)
# import correct text encoder classes
@@ -1166,10 +1076,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
vae_path = (
args.pretrained_model_name_or_path
@@ -1177,25 +1087,16 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
)
vae_scaling_factor = vae.config.scaling_factor
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
if args.train_text_encoder_ti:
# we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
# TOK2" -> ["TOK", "TOK2"] etc.
token_abstraction_list = "".join(args.token_abstraction.split()).split(",")
logger.info(f"list of token identifiers: {token_abstraction_list}")
token_abstraction_dict = {}
token_idx = 0
for i, token in enumerate(token_abstraction_list):
for i, token in enumerate(args.token_abstraction):
token_abstraction_dict[token] = [
f"<s{token_idx + i + j}>" for j in range(args.num_new_tokens_per_abstraction)
]
@@ -1315,8 +1216,6 @@ def main(args):
text_lora_parameters_one = []
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
@@ -1324,8 +1223,6 @@ def main(args):
text_lora_parameters_two = []
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
@@ -1412,16 +1309,12 @@ def main(args):
# different learning rate for text encoder and unet
text_lora_parameters_one_with_lr = {
"params": text_lora_parameters_one,
"weight_decay": args.adam_weight_decay_text_encoder
if args.adam_weight_decay_text_encoder
else args.adam_weight_decay,
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
text_lora_parameters_two_with_lr = {
"params": text_lora_parameters_two,
"weight_decay": args.adam_weight_decay_text_encoder
if args.adam_weight_decay_text_encoder
else args.adam_weight_decay,
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [
@@ -1506,12 +1399,6 @@ def main(args):
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_prompt=args.class_prompt,
dataset_name=args.dataset_name,
dataset_config_name=args.dataset_config_name,
cache_dir=args.cache_dir,
image_column=args.image_column,
train_text_encoder_ti=args.train_text_encoder_ti,
caption_column=args.caption_column,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
class_num=args.num_class_images,
@@ -1607,26 +1494,6 @@ def main(args):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
if args.train_text_encoder_ti and args.validation_prompt:
# replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
print("validation prompt:", args.validation_prompt)
if args.cache_latents:
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=torch.float32
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if args.validation_prompt is None:
del vae
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1726,10 +1593,27 @@ def main(args):
if epoch == num_train_epochs_text_encoder:
print("PIVOT HALFWAY", epoch)
# stopping optimization of text_encoder params
# re setting the optimizer to optimize only on unet params
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0
params_to_optimize = params_to_optimize[:1]
# reinitializing the optimizer to optimize only on unet params
if args.optimizer.lower() == "prodigy":
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
else: # AdamW or 8-bit-AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
else:
# still optimizng the text encoder
text_encoder_one.train()
@@ -1742,7 +1626,9 @@ def main(args):
unet.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]
print(prompts)
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if freeze_text_encoder:
@@ -1754,13 +1640,9 @@ def main(args):
tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens)
tokens_two = tokenize_prompt(tokenizer_two, prompts, add_special_tokens)
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae_scaling_factor
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)
@@ -1919,18 +1801,12 @@ def main(args):
f" {args.validation_prompt}."
)
# create pipeline
if freeze_text_encoder:
if not args.train_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
variant=args.variant,
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder_2",
revision=args.revision,
variant=args.variant,
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1939,7 +1815,6 @@ def main(args):
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -2017,15 +1892,10 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
@@ -2068,23 +1938,21 @@ def main(args):
}
)
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/embeddings.safetensors",
)
save_model_card(
model_id if not args.push_to_hub else repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
train_text_encoder_ti=args.train_text_encoder_ti,
token_abstraction_dict=train_dataset.token_abstraction_dict,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
)
if args.push_to_hub:
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/embeddings.safetensors",
)
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
+5 -326
View File
@@ -48,10 +48,8 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
|
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -79,7 +77,6 @@ from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
"longlian/lmd_plus",
custom_pipeline="llm_grounded_diffusion",
custom_revision="main",
variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
@@ -513,6 +510,7 @@ device = torch.device('cpu' if not has_cuda else 'cuda')
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
safety_checker=None,
use_auth_token=True,
custom_pipeline="imagic_stable_diffusion",
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
).to(device)
@@ -552,6 +550,7 @@ device = th.device('cpu' if not has_cuda else 'cuda')
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
use_auth_token=True,
custom_pipeline="seed_resize_stable_diffusion"
).to(device)
@@ -587,6 +586,7 @@ generator = th.Generator("cuda").manual_seed(0)
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
use_auth_token=True,
custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
).to(device)
@@ -605,6 +605,7 @@ image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=heigh
pipe_compare = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
use_auth_token=True,
custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
).to(device)
@@ -2523,181 +2524,6 @@ images[0].save("controlnet_and_adapter_inpaint.png")
```
### Regional Prompting Pipeline
This pipeline is a port of the [Regional Prompter extension](https://github.com/hako-mikan/sd-webui-regional-prompter) for [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to diffusers.
This code implements a pipeline for the Stable Diffusion model, enabling the division of the canvas into multiple regions, with different prompts applicable to each region. Users can specify regions in two ways: using `Cols` and `Rows` modes for grid-like divisions, or the `Prompt` mode for regions calculated based on prompts.
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline1.png)
### Usage
### Sample Code
```
from from examples.community.regional_prompting_stable_diffusion import RegionalPromptingStableDiffusionPipeline
pipe = RegionalPromptingStableDiffusionPipeline.from_single_file(model_path, vae=vae)
rp_args = {
"mode":"rows",
"div": "1;1;1"
}
prompt ="""
green hair twintail BREAK
red blouse BREAK
blue skirt
"""
images = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=7.5,
height = 768,
width = 512,
num_inference_steps =20,
num_images_per_prompt = 1,
rp_args = rp_args
).images
time = time.strftime(r"%Y%m%d%H%M%S")
i = 1
for image in images:
i += 1
fileName = f'img-{time}-{i+1}.png'
image.save(fileName)
```
### Cols, Rows mode
In the Cols, Rows mode, you can split the screen vertically and horizontally and assign prompts to each region. The split ratio can be specified by 'div', and you can set the division ratio like '3;3;2' or '0.1;0.5'. Furthermore, as will be described later, you can also subdivide the split Cols, Rows to specify more complex regions.
In this image, the image is divided into three parts, and a separate prompt is applied to each. The prompts are divided by 'BREAK', and each is applied to the respective region.
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline2.png)
```
green hair twintail BREAK
red blouse BREAK
blue skirt
```
### 2-Dimentional division
The prompt consists of instructions separated by the term `BREAK` and is assigned to different regions of a two-dimensional space. The image is initially split in the main splitting direction, which in this case is rows, due to the presence of a single semicolon`;`, dividing the space into an upper and a lower section. Additional sub-splitting is then applied, indicated by commas. The upper row is split into ratios of `2:1:1`, while the lower row is split into a ratio of `4:6`. Rows themselves are split in a `1:2` ratio. According to the reference image, the blue sky is designated as the first region, green hair as the second, the bookshelf as the third, and so on, in a sequence based on their position from the top left. The terrarium is placed on the desk in the fourth region, and the orange dress and sofa are in the fifth region, conforming to their respective splits.
```
rp_args = {
"mode":"rows",
"div": "1,2,1,1;2,4,6"
}
prompt ="""
blue sky BREAK
green hair BREAK
book shelf BREAK
terrarium on desk BREAK
orange dress and sofa
"""
```
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline4.png)
### Prompt Mode
There are limitations to methods of specifying regions in advance. This is because specifying regions can be a hindrance when designating complex shapes or dynamic compositions. In the region specified by the prompt, the regions is determined after the image generation has begun. This allows us to accommodate compositions and complex regions.
For further infomagen, see [here](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/main/prompt_en.md).
### syntax
```
baseprompt target1 target2 BREAK
effect1, target1 BREAK
effect2 ,target2
```
First, write the base prompt. In the base prompt, write the words (target1, target2) for which you want to create a mask. Next, separate them with BREAK. Next, write the prompt corresponding to target1. Then enter a comma and write target1. The order of the targets in the base prompt and the order of the BREAK-separated targets can be back to back.
```
target2 baseprompt target1 BREAK
effect1, target1 BREAK
effect2 ,target2
```
is also effective.
### Sample
In this example, masks are calculated for shirt, tie, skirt, and color prompts are specified only for those regions.
```
rp_args = {
"mode":"prompt-ex",
"save_mask":True,
"th": "0.4,0.6,0.6",
}
prompt ="""
a girl in street with shirt, tie, skirt BREAK
red, shirt BREAK
green, tie BREAK
blue , skirt
"""
```
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline3.png)
### threshold
The threshold used to determine the mask created by the prompt. This can be set as many times as there are masks, as the range varies widely depending on the target prompt. If multiple regions are used, enter them separated by commas. For example, hair tends to be ambiguous and requires a small value, while face tends to be large and requires a small value. These should be ordered by BREAK.
```
a lady ,hair, face BREAK
red, hair BREAK
tanned ,face
```
`threshold : 0.4,0.6`
If only one input is given for multiple regions, they are all assumed to be the same value.
### Prompt and Prompt-EX
The difference is that in Prompt, duplicate regions are added, whereas in Prompt-EX, duplicate regions are overwritten sequentially. Since they are processed in order, setting a TARGET with a large regions first makes it easier for the effect of small regions to remain unmuffled.
### Accuracy
In the case of a 512 x 512 image, Attention mode reduces the size of the region to about 8 x 8 pixels deep in the U-Net, so that small regions get mixed up; Latent mode calculates 64*64, so that the region is exact.
```
girl hair twintail frills,ribbons, dress, face BREAK
girl, ,face
```
### Mask
When an image is generated, the generated mask is displayed. It is generated at the same size as the image, but is actually used at a much smaller size.
### Use common prompt
You can attach the prompt up to ADDCOMM to all prompts by separating it first with ADDCOMM. This is useful when you want to include elements common to all regions. For example, when generating pictures of three people with different appearances, it's necessary to include the instruction of 'three people' in all regions. It's also useful when inserting quality tags and other things."For example, if you write as follows:
```
best quality, 3persons in garden, ADDCOMM
a girl white dress BREAK
a boy blue shirt BREAK
an old man red suit
```
If common is enabled, this prompt is converted to the following:
```
best quality, 3persons in garden, a girl white dress BREAK
best quality, 3persons in garden, a boy blue shirt BREAK
best quality, 3persons in garden, an old man red suit
```
### Negative prompt
Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions.
### Parameters
To activate Regional Prompter, it is necessary to enter settings in rp_args. The items that can be set are as follows. rp_args is a dictionary type.
### Input Parameters
Parameters are specified through the `rp_arg`(dictionary type).
```
rp_args = {
"mode":"rows",
"div": "1;1;1"
}
pipe(prompt =prompt, rp_args = rp_args)
```
### Required Parameters
- `mode`: Specifies the method for defining regions. Choose from `Cols`, `Rows`, `Prompt` or `Prompt-Ex`. This parameter is case-insensitive.
- `divide`: Used in `Cols` and `Rows` modes. Details on how to specify this are provided under the respective `Cols` and `Rows` sections.
- `th`: Used in `Prompt` mode. The method of specification is detailed under the `Prompt` section.
### Optional Parameters
- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.
## Diffusion Posterior Sampling Pipeline
* Reference paper
```
@@ -2839,150 +2665,3 @@ The Pipeline supports `compel` syntax. Input prompts using the `compel` structur
* ![dps_mea](https://github.com/tongdaxu/Images/assets/22267548/ff6a33d6-26f0-42aa-88ce-f8a76ba45a13)
* Reconstructed image:
* ![dps_generated_image](https://github.com/tongdaxu/Images/assets/22267548/b74f084d-93f4-4845-83d8-44c0fa758a5f)
### AnimateDiff ControlNet Pipeline
This pipeline combines AnimateDiff and ControlNet. Enjoy precise motion control for your videos! Refer to [this](https://github.com/huggingface/diffusers/issues/5866) issue for more details.
```py
import torch
from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler
from PIL import Image
motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
adapter = MotionAdapter.from_pretrained(motion_id)
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = DiffusionPipeline.from_pretrained(
model_id,
motion_adapter=adapter,
controlnet=controlnet,
vae=vae,
custom_pipeline="pipeline_animatediff_controlnet",
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
)
pipe.enable_vae_slicing()
conditioning_frames = []
for i in range(1, 16 + 1):
conditioning_frames.append(Image.open(f"frame_{i}.png"))
prompt = "astronaut in space, dancing"
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
result = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=512,
height=768,
conditioning_frames=conditioning_frames,
num_inference_steps=12,
).frames[0]
from diffusers.utils import export_to_gif
export_to_gif(result.frames[0], "result.gif")
```
<table>
<tr><td colspan="2" align=center><b>Conditioning Frames</b></td></tr>
<tr align=center>
<td align=center><img src="https://user-images.githubusercontent.com/7365912/265043418-23291941-864d-495a-8ba8-d02e05756396.gif" alt="input-frames"></td>
</tr>
<tr><td colspan="2" align=center><b>AnimateDiff model: SG161222/Realistic_Vision_V5.1_noVAE</b></td></tr>
<tr>
<td align=center><img src="https://github.com/huggingface/diffusers/assets/72266394/baf301e2-d03c-4129-bd84-203a1de2b2be" alt="gif-1"></td>
<td align=center><img src="https://github.com/huggingface/diffusers/assets/72266394/9f923475-ecaf-452b-92c8-4e42171182d8" alt="gif-2"></td>
</tr>
<tr><td colspan="2" align=center><b>AnimateDiff model: CardosAnime</b></td></tr>
<tr>
<td align=center><img src="https://github.com/huggingface/diffusers/assets/72266394/b2c41028-38a0-45d6-86ed-fec7446b87f7" alt="gif-1"></td>
<td align=center><img src="https://github.com/huggingface/diffusers/assets/72266394/eb7d2952-72e4-44fa-b664-077c79b4fc70" alt="gif-2"></td>
</tr>
</table>
### DemoFusion
This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
The original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion).
- `view_batch_size` (`int`, defaults to 16):
The batch size for multiple denoising paths. Typically, a larger batch size can result in higher efficiency but comes with increased GPU memory requirements.
- `stride` (`int`, defaults to 64):
The stride of moving local patches. A smaller stride is better for alleviating seam issues, but it also introduces additional computational overhead and inference time.
- `cosine_scale_1` (`float`, defaults to 3):
Control the strength of skip-residual. For specific impacts, please refer to Appendix C in the DemoFusion paper.
- `cosine_scale_2` (`float`, defaults to 1):
Control the strength of dilated sampling. For specific impacts, please refer to Appendix C in the DemoFusion paper.
- `cosine_scale_3` (`float`, defaults to 1):
Control the strength of the Gaussian filter. For specific impacts, please refer to Appendix C in the DemoFusion paper.
- `sigma` (`float`, defaults to 1):
The standard value of the Gaussian filter. Larger sigma promotes the global guidance of dilated sampling, but has the potential of over-smoothing.
- `multi_decoder` (`bool`, defaults to True):
Determine whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072, a tiled decoder becomes necessary.
- `show_image` (`bool`, defaults to False):
Determine whether to show intermediate results during generation.
```
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
custom_pipeline="pipeline_demofusion_sdxl",
custom_revision="main",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
prompt = "Envision a portrait of an elderly woman, her face a canvas of time, framed by a headscarf with muted tones of rust and cream. Her eyes, blue like faded denim. Her attire, simple yet dignified."
negative_prompt = "blurry, ugly, duplicate, poorly drawn, deformed, mosaic"
images = pipe(
prompt,
negative_prompt=negative_prompt,
height=3072,
width=3072,
view_batch_size=16,
stride=64,
num_inference_steps=50,
guidance_scale=7.5,
cosine_scale_1=3,
cosine_scale_2=1,
cosine_scale_3=1,
sigma=0.8,
multi_decoder=True,
show_image=True
)
```
You can display and save the generated images as:
```
def image_grid(imgs, save_path=None):
w = 0
for i, img in enumerate(imgs):
h_, w_ = imgs[i].size
w += w_
h = h_
grid = Image.new('RGB', size=(w, h))
grid_w, grid_h = grid.size
w = 0
for i, img in enumerate(imgs):
h_, w_ = imgs[i].size
grid.paste(img, box=(w, h - h_))
if save_path != None:
img.save(save_path + "/img_{}.jpg".format((i + 1) * 1024))
w += w_
return grid
image_grid(images, save_path="./outputs/")
```
![output_example](https://github.com/PRIS-CV/DemoFusion/blob/main/output_example.png)
+6 -8
View File
@@ -5,11 +5,10 @@ from typing import Dict, List, Union
import safetensors.torch
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from diffusers import DiffusionPipeline, __version__
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
class CheckpointMergerPipeline(DiffusionPipeline):
@@ -58,7 +57,6 @@ class CheckpointMergerPipeline(DiffusionPipeline):
return (temp_dict, meta_keys)
@torch.no_grad()
@validate_hf_hub_args
def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
"""
Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints(weights) of the models passed
@@ -71,7 +69,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
**kwargs:
Supports all the default DiffusionPipeline.get_config_dict kwargs viz..
cache_dir, resume_download, force_download, proxies, local_files_only, token, revision, torch_dtype, device_map.
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map.
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
@@ -83,12 +81,12 @@ class CheckpointMergerPipeline(DiffusionPipeline):
"""
# Default kwargs from DiffusionPipeline
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
device_map = kwargs.pop("device_map", None)
@@ -125,7 +123,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
)
config_dicts.append(config_dict)
@@ -161,7 +159,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
user_agent=user_agent,
+21 -619
View File
@@ -16,7 +16,6 @@
import ast
import gc
import inspect
import math
import warnings
from collections.abc import Iterable
@@ -24,29 +23,16 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention import Attention, GatedSelfAttentionDense
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import logging, replace_example_docstring
EXAMPLE_DOC_STRING = """
@@ -58,7 +44,6 @@ EXAMPLE_DOC_STRING = """
>>> pipe = DiffusionPipeline.from_pretrained(
... "longlian/lmd_plus",
... custom_pipeline="llm_grounded_diffusion",
... custom_revision="main",
... variant="fp16", torch_dtype=torch.float16
... )
>>> pipe.enable_model_cpu_offload()
@@ -111,12 +96,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
DEFAULT_GUIDANCE_ATTN_KEYS = [
("mid", 0, 0, 0),
("up", 1, 0, 0),
("up", 1, 1, 0),
("up", 1, 2, 0),
]
DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
def convert_attn_keys(key):
@@ -146,15 +126,7 @@ def scale_proportion(obj_box, H, W):
# Adapted from the parent class `AttnProcessor2_0`
class AttnProcessorWithHook(AttnProcessor2_0):
def __init__(
self,
attn_processor_key,
hidden_size,
cross_attention_dim,
hook=None,
fast_attn=True,
enabled=True,
):
def __init__(self, attn_processor_key, hidden_size, cross_attention_dim, hook=None, fast_attn=True, enabled=True):
super().__init__()
self.attn_processor_key = attn_processor_key
self.hidden_size = hidden_size
@@ -193,16 +165,15 @@ class AttnProcessorWithHook(AttnProcessor2_0):
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
query = attn.to_q(hidden_states, scale=scale)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -215,13 +186,7 @@ class AttnProcessorWithHook(AttnProcessor2_0):
if self.hook is not None and self.enabled:
# Call the hook with query, key, value, and attention maps
self.hook(
self.attn_processor_key,
query_batch_dim,
key_batch_dim,
value_batch_dim,
attention_probs,
)
self.hook(self.attn_processor_key, query_batch_dim, key_batch_dim, value_batch_dim, attention_probs)
if self.fast_attn:
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
@@ -237,12 +202,7 @@ class AttnProcessorWithHook(AttnProcessor2_0):
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -251,7 +211,7 @@ class AttnProcessorWithHook(AttnProcessor2_0):
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -266,9 +226,7 @@ class AttnProcessorWithHook(AttnProcessor2_0):
return hidden_states
class LLMGroundedDiffusionPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
r"""
Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf.
@@ -299,11 +257,6 @@ class LLMGroundedDiffusionPipeline(
Whether a safety checker is needed for this pipeline.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
objects_text = "Objects: "
bg_prompt_text = "Background prompt: "
bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
@@ -319,91 +272,12 @@ class LLMGroundedDiffusionPipeline(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
# This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Initialize the attention hooks for LLM-grounded Diffusion
self.register_attn_hooks(unet)
self._saved_attn = None
@@ -590,14 +464,7 @@ class LLMGroundedDiffusionPipeline(
return token_map
def get_phrase_indices(
self,
prompt,
phrases,
token_map=None,
add_suffix_if_not_found=False,
verbose=False,
):
def get_phrase_indices(self, prompt, phrases, token_map=None, add_suffix_if_not_found=False, verbose=False):
for obj in phrases:
# Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
if obj not in prompt:
@@ -618,14 +485,7 @@ class LLMGroundedDiffusionPipeline(
phrase_token_map_str = " ".join(phrase_token_map)
if verbose:
logger.info(
"Full str:",
token_map_str,
"Substr:",
phrase_token_map_str,
"Phrase:",
phrases,
)
logger.info("Full str:", token_map_str, "Substr:", phrase_token_map_str, "Phrase:", phrases)
# Count the number of token before substr
# The substring comes with a trailing space that needs to be removed by minus one in the index.
@@ -692,15 +552,7 @@ class LLMGroundedDiffusionPipeline(
return loss
def compute_ca_loss(
self,
saved_attn,
bboxes,
phrase_indices,
guidance_attn_keys,
verbose=False,
**kwargs,
):
def compute_ca_loss(self, saved_attn, bboxes, phrase_indices, guidance_attn_keys, verbose=False, **kwargs):
"""
The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss.
`AttnProcessor` will put attention maps into the `save_attn_to_dict`.
@@ -753,7 +605,6 @@ class LLMGroundedDiffusionPipeline(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -811,7 +662,6 @@ class LLMGroundedDiffusionPipeline(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -874,10 +724,9 @@ class LLMGroundedDiffusionPipeline(
phrase_indices = []
prompt_parsed = []
for prompt_item in prompt:
(
phrase_indices_parsed_item,
prompt_parsed_item,
) = self.get_phrase_indices(prompt_item, add_suffix_if_not_found=True)
phrase_indices_parsed_item, prompt_parsed_item = self.get_phrase_indices(
prompt_item, add_suffix_if_not_found=True
)
phrase_indices.append(phrase_indices_parsed_item)
prompt_parsed.append(prompt_parsed_item)
prompt = prompt_parsed
@@ -910,11 +759,6 @@ class LLMGroundedDiffusionPipeline(
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -957,10 +801,7 @@ class LLMGroundedDiffusionPipeline(
if n_objs:
cond_boxes[:n_objs] = torch.tensor(boxes)
text_embeddings = torch.zeros(
max_objs,
self.unet.config.cross_attention_dim,
device=device,
dtype=self.text_encoder.dtype,
max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
)
if n_objs:
text_embeddings[:n_objs] = _text_embeddings
@@ -992,9 +833,6 @@ class LLMGroundedDiffusionPipeline(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
loss_attn = torch.tensor(10000.0)
# 7. Denoising loop
@@ -1031,7 +869,6 @@ class LLMGroundedDiffusionPipeline(
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
@@ -1176,438 +1013,3 @@ class LLMGroundedDiffusionPipeline(
self.enable_attn_hook(enabled=False)
return latents, loss
# Below are methods copied from StableDiffusionPipeline
# The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
prompt_embeds_tuple = self.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
**kwargs,
)
# concatenate for backwards comp
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
output_hidden_states=True,
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
timesteps (`torch.Tensor`):
generate embedding vectors at these timesteps
embedding_dim (`int`, *optional*, defaults to 512):
dimension of the embeddings to generate
dtype:
data type of the generated embeddings
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
@property
def guidance_scale(self):
return self._guidance_scale
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale
@property
def guidance_rescale(self):
return self._guidance_rescale
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
@property
def num_timesteps(self):
return self._num_timesteps
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,589 +0,0 @@
import math
from typing import Dict, Optional
import torch
import torchvision.transforms.functional as FF
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND
try:
from compel import Compel
except ImportError:
Compel = None
KCOMM = "ADDCOMM"
KBRK = "BREAK"
class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
r"""
Args for Regional Prompting Pipeline:
rp_args:dict
Required
rp_args["mode"]: cols, rows, prompt, prompt-ex
for cols, rows mode
rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
for prompt, prompt-ex mode
rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)
Optional
rp_args["save_mask"]: True/False (save masks in prompt mode)
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
@torch.no_grad()
def __call__(
self,
prompt: str,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: str = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
rp_args: Dict[str, str] = None,
):
active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721
if negative_prompt is None:
negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721
device = self._execution_device
regions = 0
self.power = int(rp_args["power"]) if "power" in rp_args else 1
prompts = prompt if type(prompt) == list else [prompt] # noqa: E721
n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721
self.batch = batch = num_images_per_prompt * len(prompts)
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
cn = len(all_prompts_cn) == len(all_n_prompts_cn)
if Compel:
compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
def getcompelembs(prps):
embl = []
for prp in prps:
embl.append(compel.build_conditioning_tensor(prp))
return torch.cat(embl)
conds = getcompelembs(all_prompts_cn)
unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
embs = getcompelembs(prompts)
n_embs = getcompelembs(n_prompts)
prompt = negative_prompt = None
else:
conds = self.encode_prompt(prompts, device, 1, True)[0]
unconds = (
self.encode_prompt(n_prompts, device, 1, True)[0]
if cn
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
)
embs = n_embs = None
if not active:
pcallback = None
mode = None
else:
if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
ocells, icells, regions = make_cells(rp_args["div"])
elif "PRO" in rp_args["mode"].upper():
regions = len(all_prompts_p[0])
mode = "PROMPT"
reset_attnmaps(self)
self.ex = "EX" in rp_args["mode"].upper()
self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
thresholds = [float(x) for x in rp_args["th"].split(",")]
orig_hw = (height, width)
revers = True
def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor, selfs=None):
if "PRO" in mode: # in Prompt mode, make masks from sum of attension maps
self.step = step
if len(self.attnmaps_sizes) > 3:
self.history[step] = self.attnmaps.copy()
for hw in self.attnmaps_sizes:
allmasks = []
basemasks = [None] * batch
for tt, th in zip(target_tokens, thresholds):
for b in range(batch):
key = f"{tt}-{b}"
_, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
mask = mask.unsqueeze(0).unsqueeze(-1)
if self.ex:
allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
allmasks.append(mask)
basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
basemasks = [1 - mask for mask in basemasks]
basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
allmasks = basemasks + allmasks
self.attnmasks[hw] = torch.cat(allmasks)
self.maskready = True
return latents
def hook_forward(module):
# diffusers==0.23.2
def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
attn = module
xshape = hidden_states.shape
self.hw = (h, w) = split_dims(xshape[1], *orig_hw)
if revers:
nx, px = hidden_states.chunk(2)
else:
px, nx = hidden_states.chunk(2)
if cn:
hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
encoder_hidden_states = torch.cat([conds] + [unconds])
else:
hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
encoder_hidden_states = torch.cat([conds] + [unconds])
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = scaled_dot_product_attention(
self,
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
getattn="PRO" in mode,
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
#### Regional Prompting Col/Row mode
if any(x in mode for x in ["COL", "ROW"]):
reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
outs = [px, nx] if cn else [px]
for out in outs:
c = 0
for i, ocell in enumerate(ocells):
for icell in icells[i]:
if "ROW" in mode:
out[
0:batch,
int(h * ocell[0]) : int(h * ocell[1]),
int(w * icell[0]) : int(w * icell[1]),
:,
] = out[
c * batch : (c + 1) * batch,
int(h * ocell[0]) : int(h * ocell[1]),
int(w * icell[0]) : int(w * icell[1]),
:,
]
else:
out[
0:batch,
int(h * icell[0]) : int(h * icell[1]),
int(w * ocell[0]) : int(w * ocell[1]),
:,
] = out[
c * batch : (c + 1) * batch,
int(h * icell[0]) : int(h * icell[1]),
int(w * ocell[0]) : int(w * ocell[1]),
:,
]
c += 1
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
hidden_states = hidden_states.reshape(xshape)
#### Regional Prompting Prompt mode
elif "PRO" in mode:
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
if (h, w) in self.attnmasks and self.maskready:
def mask(input):
out = torch.multiply(input, self.attnmasks[(h, w)])
for b in range(batch):
for r in range(1, regions):
out[b] = out[b] + out[r * batch + b]
return out
px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
return hidden_states
return forward
def hook_forwards(root_module: torch.nn.Module):
for name, module in root_module.named_modules():
if "attn2" in name and module.__class__.__name__ == "Attention":
module.forward = hook_forward(module)
hook_forwards(self.unet)
output = StableDiffusionPipeline(**self.components)(
prompt=prompt,
prompt_embeds=embs,
negative_prompt=negative_prompt,
negative_prompt_embeds=n_embs,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback_on_step_end=pcallback,
)
if "save_mask" in rp_args:
save_mask = rp_args["save_mask"]
else:
save_mask = False
if mode == "PROMPT" and save_mask:
saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
return output
### Make prompt list for each regions
def promptsmaker(prompts, batch):
out_p = []
plen = len(prompts)
for prompt in prompts:
add = ""
if KCOMM in prompt:
add, prompt = prompt.split(KCOMM)
add = add + " "
prompts = prompt.split(KBRK)
out_p.append([add + p for p in prompts])
out = [None] * batch * len(out_p[0]) * len(out_p)
for p, prs in enumerate(out_p): # inputs prompts
for r, pr in enumerate(prs): # prompts for regions
start = (p + r * plen) * batch
out[start : start + batch] = [pr] * batch # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
return out, out_p
### make regions from ratios
### ";" makes outercells, "," makes inner cells
def make_cells(ratios):
if ";" not in ratios and "," in ratios:
ratios = ratios.replace(",", ";")
ratios = ratios.split(";")
ratios = [inratios.split(",") for inratios in ratios]
icells = []
ocells = []
def startend(cells, array):
current_start = 0
array = [float(x) for x in array]
for value in array:
end = current_start + (value / sum(array))
cells.append([current_start, end])
current_start = end
startend(ocells, [r[0] for r in ratios])
for inratios in ratios:
if 2 > len(inratios):
icells.append([[0, 1]])
else:
add = []
startend(add, inratios[1:])
icells.append(add)
return ocells, icells, sum(len(cell) for cell in icells)
def make_emblist(self, prompts):
with torch.no_grad():
tokens = self.tokenizer(
prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
).input_ids.to(self.device)
embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
return embs
def split_dims(xs, height, width):
xs = xs
def repeat_div(x, y):
while y > 0:
x = math.ceil(x / 2)
y = y - 1
return x
scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
dsh = repeat_div(height, scale)
dsw = repeat_div(width, scale)
return dsh, dsw
##### for prompt mode
def get_attn_maps(self, attn):
height, width = self.hw
target_tokens = self.target_tokens
if (height, width) not in self.attnmaps_sizes:
self.attnmaps_sizes.append((height, width))
for b in range(self.batch):
for t in target_tokens:
power = self.power
add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
add = torch.sum(add, dim=2)
key = f"{t}-{b}"
if key not in self.attnmaps:
self.attnmaps[key] = add
else:
if self.attnmaps[key].shape[1] != add.shape[1]:
add = add.view(8, height, width)
add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
add = add.reshape_as(self.attnmaps[key])
self.attnmaps[key] = self.attnmaps[key] + add
def reset_attnmaps(self): # init parameters in every batch
self.step = 0
self.attnmaps = {} # maked from attention maps
self.attnmaps_sizes = [] # height,width set of u-net blocks
self.attnmasks = {} # maked from attnmaps for regions
self.maskready = False
self.history = {}
def saveattnmaps(self, output, h, w, th, step, regions):
masks = []
for i, mask in enumerate(self.history[step].values()):
img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
if self.ex:
masks = [x - mask for x in masks]
masks.append(mask)
if len(masks) == regions - 1:
output.images.extend([FF.to_pil_image(mask) for mask in masks])
masks = []
else:
output.images.append(img)
def makepmask(
self, mask, h, w, th, step
): # make masks from attention cache return [for preview, for attention, for Latent]
th = th - step * 0.005
if 0.05 >= th:
th = 0.05
mask = torch.mean(mask, dim=0)
mask = mask / mask.max().item()
mask = torch.where(mask > th, 1, 0)
mask = mask.float()
mask = mask.view(1, *self.attnmaps_sizes[0])
img = FF.to_pil_image(mask)
img = img.resize((w, h))
mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
lmask = mask
mask = mask.reshape(h * w)
mask = torch.where(mask > 0.1, 1, 0)
return img, mask, lmask
def tokendealer(self, all_prompts):
for prompts in all_prompts:
targets = [p.split(",")[-1] for p in prompts[1:]]
tt = []
for target in targets:
ptokens = (
self.tokenizer(
prompts,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
).input_ids
)[0]
ttokens = (
self.tokenizer(
target,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
).input_ids
)[0]
tlist = []
for t in range(ttokens.shape[0] - 2):
for p in range(ptokens.shape[0]):
if ttokens[t + 1] == ptokens[p]:
tlist.append(p)
if tlist != []:
tt.append(tlist)
return tt
def scaled_dot_product_attention(
self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
) -> torch.Tensor:
# Efficient implementation equivalent to the following:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
if getattn:
get_attn_maps(self, attn_weight)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
@@ -28,7 +28,6 @@ import PIL.Image
import tensorrt as trt
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
@@ -42,7 +41,7 @@ from polygraphy.backend.trt import (
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
@@ -51,7 +50,7 @@ from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
from diffusers.utils import DIFFUSERS_CACHE, logging
"""
@@ -710,7 +709,6 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
image_height: int = 512,
@@ -726,15 +724,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
timing_cache: str = "timing_cache",
):
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
)
self.vae.forward = self.vae.decode
@@ -779,13 +769,12 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
@classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
cls.cached_folder = (
@@ -797,7 +786,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
)
)
@@ -28,7 +28,6 @@ import PIL.Image
import tensorrt as trt
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
@@ -42,7 +41,7 @@ from polygraphy.backend.trt import (
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
@@ -52,7 +51,7 @@ from diffusers.pipelines.stable_diffusion import (
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
from diffusers.utils import DIFFUSERS_CACHE, logging
"""
@@ -711,7 +710,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
image_height: int = 512,
@@ -727,15 +725,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
timing_cache: str = "timing_cache",
):
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
)
self.vae.forward = self.vae.decode
@@ -780,13 +770,12 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
@classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
cls.cached_folder = (
@@ -798,7 +787,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
)
)
@@ -27,7 +27,6 @@ import onnx_graphsurgeon as gs
import tensorrt as trt
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
@@ -41,7 +40,7 @@ from polygraphy.backend.trt import (
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
@@ -50,7 +49,7 @@ from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
from diffusers.utils import DIFFUSERS_CACHE, logging
"""
@@ -625,7 +624,6 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae"],
image_height: int = 768,
@@ -641,15 +639,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
timing_cache: str = "timing_cache",
):
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
)
self.vae.forward = self.vae.decode
@@ -692,13 +682,12 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
self.models["vae"] = make_VAE(self.vae, **models_args)
@classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
cls.cached_folder = (
@@ -710,7 +699,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
)
)
+14 -22
View File
@@ -1,6 +1,6 @@
# Latent Consistency Distillation Example:
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with few timesteps.
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill stable-diffusion-v1.5 for less timestep inference.
## Full model distillation
@@ -24,7 +24,7 @@ Then cd in the example folder and run
pip install -r requirements.txt
```
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
@@ -46,16 +46,12 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
#### Example
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
#### Example with LAION-A6+ dataset
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"
accelerate launch train_lcm_distill_sd_wds.py \
--pretrained_teacher_model=$MODEL_NAME \
runwayml/stable-diffusion-v1-5
PROGRAM="train_lcm_distill_sd_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
--resolution=512 \
@@ -63,7 +59,7 @@ accelerate launch train_lcm_distill_sd_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
@@ -73,23 +69,19 @@ accelerate launch train_lcm_distill_sd_wds.py \
--resume_from_checkpoint=latest \
--report_to=wandb \
--seed=453645634 \
--push_to_hub
--push_to_hub \
```
## LCM-LoRA
Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
### Example
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
### Example with LAION-A6+ dataset
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"
accelerate launch train_lcm_distill_lora_sd_wds.py \
--pretrained_teacher_model=$MODEL_NAME \
runwayml/stable-diffusion-v1-5
PROGRAM="train_lcm_distill_lora_sd_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
--resolution=512 \
@@ -98,7 +90,7 @@ accelerate launch train_lcm_distill_lora_sd_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
@@ -1,6 +1,6 @@
# Latent Consistency Distillation Example:
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with few timesteps.
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill SDXL for less timestep inference.
## Full model distillation
@@ -24,7 +24,7 @@ Then cd in the example folder and run
pip install -r requirements.txt
```
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
@@ -46,16 +46,12 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
#### Example
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
#### Example with LAION-A6+ dataset
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path/to/saved/model"
accelerate launch train_lcm_distill_sdxl_wds.py \
--pretrained_teacher_model=$MODEL_NAME \
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
PROGRAM="train_lcm_distill_sdxl_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
@@ -64,7 +60,7 @@ accelerate launch train_lcm_distill_sdxl_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
@@ -81,15 +77,11 @@ accelerate launch train_lcm_distill_sdxl_wds.py \
Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
### Example
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
### Example with LAION-A6+ dataset
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path/to/saved/model"
accelerate launch train_lcm_distill_lora_sdxl_wds.py \
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir=$OUTPUT_DIR \
@@ -100,7 +92,7 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
@@ -71,7 +71,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.18.0.dev0")
logger = get_logger(__name__)
@@ -423,7 +423,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
)
model_class = text_encoder_config.architectures[0]
@@ -1123,7 +1123,7 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
image, text = batch
image, text, _, _ = batch
image = image.to(accelerator.device, non_blocking=True)
encoded_text = compute_embeddings_fn(text)
@@ -68,16 +68,11 @@ from diffusers.utils.import_utils import is_xformers_available
MAX_SEQ_LENGTH = 77
# Adjust for your dataset
WDS_JSON_WIDTH = "width" # original_width for LAION
WDS_JSON_HEIGHT = "height" # original_height for LAION
MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.18.0.dev0")
logger = get_logger(__name__)
@@ -151,10 +146,10 @@ class WebdatasetFilter:
try:
if "json" in x:
x_json = json.loads(x["json"])
filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
WDS_JSON_HEIGHT, 0
filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
"original_height", 0
) >= self.min_size
filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
return filter_size and filter_watermark
else:
return False
@@ -185,7 +180,7 @@ class Text2ImageDataset:
if use_fix_crop_and_size:
return (resolution, resolution)
else:
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
def transform(example):
# resize image
@@ -217,7 +212,7 @@ class Text2ImageDataset:
pipeline = [
wds.ResampledShards(train_shards_path_or_url),
tarfile_to_samples_nothrow,
wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
wds.select(WebdatasetFilter(min_size=960)),
wds.shuffle(shuffle_buffer_size),
*processing_pipeline,
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
@@ -397,7 +392,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
)
model_class = text_encoder_config.architectures[0]
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.18.0.dev0")
logger = get_logger(__name__)
@@ -400,7 +400,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
)
model_class = text_encoder_config.architectures[0]
@@ -1106,7 +1106,7 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
image, text = batch
image, text, _, _ = batch
image = image.to(accelerator.device, non_blocking=True)
encoded_text = compute_embeddings_fn(text)
@@ -67,16 +67,11 @@ from diffusers.utils.import_utils import is_xformers_available
MAX_SEQ_LENGTH = 77
# Adjust for your dataset
WDS_JSON_WIDTH = "width" # original_width for LAION
WDS_JSON_HEIGHT = "height" # original_height for LAION
MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.18.0.dev0")
logger = get_logger(__name__)
@@ -133,10 +128,10 @@ class WebdatasetFilter:
try:
if "json" in x:
x_json = json.loads(x["json"])
filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
WDS_JSON_HEIGHT, 0
filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
"original_height", 0
) >= self.min_size
filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
return filter_size and filter_watermark
else:
return False
@@ -167,7 +162,7 @@ class Text2ImageDataset:
if use_fix_crop_and_size:
return (resolution, resolution)
else:
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
def transform(example):
# resize image
@@ -199,7 +194,7 @@ class Text2ImageDataset:
pipeline = [
wds.ResampledShards(train_shards_path_or_url),
tarfile_to_samples_nothrow,
wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
wds.select(WebdatasetFilter(min_size=960)),
wds.shuffle(shuffle_buffer_size),
*processing_pipeline,
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
@@ -419,7 +414,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
)
model_class = text_encoder_config.architectures[0]
+1 -1
View File
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
-1
View File
@@ -44,7 +44,6 @@ write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
-1
View File
@@ -47,7 +47,6 @@ write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
-1
View File
@@ -4,4 +4,3 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft==0.7.0
@@ -4,4 +4,3 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft==0.7.0
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+100 -28
View File
@@ -16,6 +16,7 @@
import argparse
import copy
import gc
import itertools
import logging
import math
import os
@@ -34,8 +35,6 @@ from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
@@ -53,13 +52,20 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
@@ -858,19 +864,79 @@ def main(args):
text_encoder.gradient_checkpointing_enable()
# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
unet.add_adapter(unet_lora_config)
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
# The text encoder comes from 🤗 transformers, we will also attach adapters to it.
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
# => 32 layers
# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
text_encoder.add_adapter(text_lora_config)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
attn_module.add_k_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_k_proj.in_features,
out_features=attn_module.add_k_proj.out_features,
rank=args.rank,
)
)
attn_module.add_v_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_v_proj.in_features,
out_features=attn_module.add_v_proj.out_features,
rank=args.rank,
)
)
unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
@@ -882,9 +948,9 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_layers_to_save = unet_lora_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -944,10 +1010,11 @@ def main(args):
optimizer_class = torch.optim.AdamW
# Optimizer creation
params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
if args.train_text_encoder:
params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
params_to_optimize = (
itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder
else unet_lora_parameters
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -1190,7 +1257,12 @@ def main(args):
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
params_to_clip = (
itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder
else unet_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -1313,19 +1385,19 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = unet_lora_state_dict(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet)
if args.train_text_encoder:
if text_encoder is not None and args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
text_encoder = text_encoder.to(torch.float32)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
else:
text_encoder_state_dict = None
text_encoder_lora_layers = None
LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)
# Final inference
@@ -34,8 +34,6 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
@@ -52,14 +50,15 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
@@ -1010,19 +1009,54 @@ def main(args):
text_encoder_two.gradient_checkpointing_enable()
# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
)
unet.add_adapter(unet_lora_config)
# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
text_encoder_one, dtype=torch.float32, rank=args.rank
)
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
text_encoder_two, dtype=torch.float32, rank=args.rank
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
@@ -1035,11 +1069,11 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_layers_to_save = unet_lora_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1096,12 +1130,6 @@ def main(args):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
if args.train_text_encoder:
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
# Optimization parameters
unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
if args.train_text_encoder:
@@ -1166,10 +1194,26 @@ def main(args):
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warn(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warn(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
@@ -1615,13 +1659,13 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = get_peft_model_state_dict(unet)
unet_lora_layers = unet_lora_state_dict(unet)
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -420,7 +420,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
)
model_class = text_encoder_config.architectures[0]
@@ -975,7 +975,7 @@ def main(args):
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True
)
if args.controlnet_model_name_or_path:
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
-2
View File
@@ -32,8 +32,6 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e
accelerate config
```
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Pokemon example
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
-1
View File
@@ -45,7 +45,6 @@ write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Training
-1
View File
@@ -5,4 +5,3 @@ datasets
ftfy
tensorboard
Jinja2
peft==0.7.0
@@ -5,4 +5,3 @@ ftfy
tensorboard
Jinja2
datasets
peft==0.7.0
@@ -53,7 +53,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = logging.getLogger(__name__)
@@ -34,14 +34,13 @@ from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
@@ -49,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -480,20 +479,62 @@ def main():
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Freeze the unet parameters before adding adapters
for param in unet.parameters():
param.requires_grad_(False)
unet_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
)
# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
unet.add_adapter(unet_lora_config)
# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
# => 32 layers
# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -508,8 +549,6 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
@@ -534,7 +573,7 @@ def main():
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
lora_layers,
unet_lora_parameters,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -661,8 +700,8 @@ def main():
)
# Prepare everything with our `accelerator`.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -794,7 +833,7 @@ def main():
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers
params_to_clip = unet_lora_parameters
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -831,15 +870,6 @@ def main():
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
unet_lora_state_dict = get_peft_model_state_dict(unet)
StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -896,13 +926,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet)
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
unet.save_attn_procs(args.output_dir)
if args.push_to_hub:
save_model_card(
@@ -16,6 +16,7 @@
"""Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""
import argparse
import itertools
import logging
import math
import os
@@ -36,8 +37,6 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
@@ -51,6 +50,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
@@ -658,20 +658,53 @@ def main(args):
# now we will add new LoRA weights to the attention layers
# Set correct lora layers
unet_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
)
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
unet.add_adapter(unet_lora_config)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
# The text encoder comes from 🤗 transformers, we will also attach adapters to it.
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
text_encoder_one, dtype=torch.float32, rank=args.rank
)
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
text_encoder_two, dtype=torch.float32, rank=args.rank
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
@@ -684,11 +717,11 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -759,13 +792,11 @@ def main(args):
optimizer_class = torch.optim.AdamW
# Optimizer creation
params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
if args.train_text_encoder:
params_to_optimize = (
params_to_optimize
+ list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+ list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
)
params_to_optimize = (
itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
if args.train_text_encoder
else unet_lora_parameters
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -1097,7 +1128,12 @@ def main(args):
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
params_to_clip = (
itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
if args.train_text_encoder
else unet_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -1193,21 +1229,20 @@ def main(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_layers = unet_attn_processors_state_dict(unet)
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two)
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
StableDiffusionXLPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+25 -6
View File
@@ -118,10 +118,9 @@ _deps = [
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
"ruff==0.1.5",
"ruff>=0.1.5,<=0.2",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
"GitPython<3.1.19",
"scipy",
"onnx",
"regex!=2019.12.17",
@@ -207,7 +206,6 @@ extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
extras["test"] = deps_list(
"compel",
"GitPython",
"datasets",
"Jinja2",
"invisible-watermark",
@@ -251,13 +249,13 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.25.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.24.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="deep learning diffusion jax pytorch stable diffusion audioldm",
license="Apache 2.0 License",
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/diffusers/graphs/contributors)",
license="Apache",
author="The HuggingFace team",
author_email="patrick@huggingface.co",
url="https://github.com/huggingface/diffusers",
package_dir={"": "src"},
@@ -281,3 +279,24 @@ setup(
+ [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)],
cmdclass={"deps_table_update": DepsTableUpdateCommand},
)
# Release checklist
# 1. Change the version in __init__.py and setup.py.
# 2. Commit these changes with the message: "Release: Release"
# 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for PyPI'"
# Push the tag to git: git push --tags origin main
# 4. Run the following commands in the top-level directory:
# python setup.py bdist_wheel
# python setup.py sdist
# 5. Upload the package to the PyPI test server first:
# twine upload dist/* -r pypitest
# twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
# 6. Check that you can install it in a virtualenv by running:
# pip install -i https://testpypi.python.org/pypi diffusers
# diffusers env
# diffusers test
# 7. Upload the final version to the actual PyPI:
# twine upload dist/* -r pypi
# 8. Add release notes to the tag in GitHub once everything is looking hunky-dory.
# 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to main.
+1 -7
View File
@@ -1,4 +1,4 @@
__version__ = "0.25.0.dev0"
__version__ = "0.24.0.dev0"
from typing import TYPE_CHECKING
@@ -80,7 +80,6 @@ else:
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSModel",
"Kandinsky3UNet",
"ModelMixin",
"MotionAdapter",
@@ -251,7 +250,6 @@ else:
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionControlNetXSPipeline",
"StableDiffusionDepth2ImgPipeline",
"StableDiffusionDiffEditPipeline",
"StableDiffusionGLIGENPipeline",
@@ -275,7 +273,6 @@ else:
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPipeline",
"StableDiffusionXLControlNetXSPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
@@ -457,7 +454,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSModel,
Kandinsky3UNet,
ModelMixin,
MotionAdapter,
@@ -607,7 +603,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionControlNetXSPipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
@@ -631,7 +626,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLControlNetXSPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
+11 -10
View File
@@ -19,7 +19,6 @@ Usage example:
import glob
import json
import warnings
from argparse import ArgumentParser, Namespace
from importlib import import_module
@@ -33,12 +32,12 @@ from . import BaseDiffusersCLICommand
def conversion_command_factory(args: Namespace):
if args.use_auth_token:
warnings.warn(
"The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
" handled automatically if user is logged in."
)
return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
return FP16SafetensorsCommand(
args.ckpt_id,
args.fp16,
args.use_safetensors,
args.use_auth_token,
)
class FP16SafetensorsCommand(BaseDiffusersCLICommand):
@@ -63,7 +62,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
)
conversion_parser.set_defaults(func=conversion_command_factory)
def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool):
self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
self.ckpt_id = ckpt_id
self.local_ckpt_dir = f"/tmp/{ckpt_id}"
@@ -76,6 +75,8 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
"When `use_safetensors` and `fp16` both are False, then this command is of no use."
)
self.use_auth_token = use_auth_token
def run(self):
if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
raise ImportError(
@@ -86,7 +87,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
from huggingface_hub import create_commit
from huggingface_hub._commit_api import CommitOperationAdd
model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token)
with open(model_index, "r") as f:
pipeline_class_name = json.load(f)["_class_name"]
pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
@@ -95,7 +96,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
# Load the appropriate pipeline. We could have use `DiffusionPipeline`
# here, but just to avoid any rough edge cases.
pipeline = pipeline_class.from_pretrained(
self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token
)
pipeline.save_pretrained(
self.local_ckpt_dir,
+8 -12
View File
@@ -27,16 +27,12 @@ from typing import Any, Dict, Tuple, Union
import numpy as np
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from . import __version__
from .utils import (
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DummyObject,
deprecate,
@@ -279,7 +275,6 @@ class ConfigMixin:
return cls.load_config(*args, **kwargs)
@classmethod
@validate_hf_hub_args
def load_config(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
@@ -316,7 +311,7 @@ class ConfigMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -334,11 +329,11 @@ class ConfigMixin:
A dictionary of all the parameters stored in a JSON configuration file.
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
_ = kwargs.pop("mirror", None)
@@ -381,7 +376,7 @@ class ConfigMixin:
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
@@ -390,7 +385,8 @@ class ConfigMixin:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
" token having permission to this repo with `token` or log in with `huggingface-cli login`."
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
" login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
+1 -2
View File
@@ -30,10 +30,9 @@ deps = {
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
"ruff": "ruff==0.1.5",
"ruff": "ruff>=0.1.5,<=0.2",
"safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
"regex": "regex!=2019.12.17",
@@ -113,7 +113,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
# TODO: verify deprecation of this kwarg
x = self.scheduler.step(prev_x, i, x)["prev_sample"]
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
# apply conditions to the trajectory (set the initial state)
x = self.reset_x0(x, conditions, self.action_dim)
+7 -7
View File
@@ -15,10 +15,11 @@ import os
from typing import Dict, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
is_transformers_available,
logging,
@@ -42,7 +43,6 @@ logger = logging.get_logger(__name__)
class IPAdapterMixin:
"""Mixin for handling IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -77,7 +77,7 @@ class IPAdapterMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -88,12 +88,12 @@ class IPAdapterMixin:
"""
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
user_agent = {
@@ -110,7 +110,7 @@ class IPAdapterMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
+36 -29
View File
@@ -18,13 +18,14 @@ from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
from huggingface_hub import model_info
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version
from torch import nn
from .. import __version__
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
@@ -131,7 +132,6 @@ class LoraLoaderMixin:
)
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -174,7 +174,7 @@ class LoraLoaderMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -195,12 +195,12 @@ class LoraLoaderMixin:
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
@@ -239,7 +239,7 @@ class LoraLoaderMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -265,7 +265,7 @@ class LoraLoaderMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -391,10 +391,6 @@ class LoraLoaderMixin:
# their prefixes.
keys = list(state_dict.keys())
if all(key.startswith("unet.unet") for key in keys):
deprecation_message = "Keys starting with 'unet.unet' are deprecated."
deprecate("unet.unet keys", "0.27", deprecation_message)
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
@@ -411,9 +407,8 @@ class LoraLoaderMixin:
else:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
if not USE_PEFT_BACKEND:
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
logger.warn(warn_message)
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
logger.warn(warn_message)
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
@@ -680,7 +675,8 @@ class LoraLoaderMixin:
@classmethod
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
if version.parse(__version__) > version.parse("0.23"):
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.25", LORA_DEPRECATION_MESSAGE)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
@@ -708,7 +704,8 @@ class LoraLoaderMixin:
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
if version.parse(__version__) > version.parse("0.23"):
deprecate("_modify_text_encoder", "0.25", LORA_DEPRECATION_MESSAGE)
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
@@ -805,21 +802,29 @@ class LoraLoaderMixin:
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
# Create a flat dictionary.
state_dict = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
# Populate the dictionary.
if unet_lora_layers is not None:
weights = (
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
)
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
state_dict.update(unet_lora_state_dict)
if unet_lora_layers:
state_dict.update(pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers is not None:
weights = (
text_encoder_lora_layers.state_dict()
if isinstance(text_encoder_lora_layers, torch.nn.Module)
else text_encoder_lora_layers
)
if text_encoder_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
text_encoder_lora_state_dict = {
f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
}
state_dict.update(text_encoder_lora_state_dict)
# Save the model
cls.write_lora_layers(
@@ -943,7 +948,8 @@ class LoraLoaderMixin:
module.merge()
else:
deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
if version.parse(__version__) > version.parse("0.23"):
deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
for _, attn_module in text_encoder_attn_modules(text_encoder):
@@ -1000,7 +1006,8 @@ class LoraLoaderMixin:
module.unmerge()
else:
deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
if version.parse(__version__) > version.parse("0.23"):
deprecate("unfuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
+18 -20
View File
@@ -18,9 +18,10 @@ from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
deprecate,
is_accelerate_available,
is_omegaconf_available,
@@ -51,7 +52,6 @@ class FromSingleFileMixin:
return cls.from_single_file(*args, **kwargs)
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
@@ -81,7 +81,7 @@ class FromSingleFileMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -154,12 +154,12 @@ class FromSingleFileMixin:
original_config_file = kwargs.pop("original_config_file", None)
config_files = kwargs.pop("config_files", None)
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
extract_ema = kwargs.pop("extract_ema", False)
image_size = kwargs.pop("image_size", None)
@@ -253,7 +253,7 @@ class FromSingleFileMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)
@@ -282,7 +282,7 @@ class FromSingleFileMixin:
)
if torch_dtype is not None:
pipe.to(dtype=torch_dtype)
pipe.to(torch_dtype=torch_dtype)
return pipe
@@ -293,7 +293,6 @@ class FromOriginalVAEMixin:
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
@@ -323,7 +322,7 @@ class FromOriginalVAEMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -380,12 +379,12 @@ class FromOriginalVAEMixin:
)
config_file = kwargs.pop("config_file", None)
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
image_size = kwargs.pop("image_size", None)
scaling_factor = kwargs.pop("scaling_factor", None)
@@ -426,7 +425,7 @@ class FromOriginalVAEMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)
@@ -491,7 +490,6 @@ class FromOriginalControlnetMixin:
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
@@ -521,7 +519,7 @@ class FromOriginalControlnetMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -557,12 +555,12 @@ class FromOriginalControlnetMixin:
from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
config_file = kwargs.pop("config_file", None)
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
num_in_channels = kwargs.pop("num_in_channels", None)
use_linear_projection = kwargs.pop("use_linear_projection", None)
revision = kwargs.pop("revision", None)
@@ -605,7 +603,7 @@ class FromOriginalControlnetMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)
+14 -10
View File
@@ -15,10 +15,16 @@ from typing import Dict, List, Optional, Union
import safetensors
import torch
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
is_accelerate_available,
is_transformers_available,
logging,
)
if is_transformers_available():
@@ -33,14 +39,13 @@ TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@validate_hf_hub_args
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
@@ -74,7 +79,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -95,7 +100,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -262,7 +267,6 @@ class TextualInversionLoaderMixin:
return all_tokens, all_embeddings
@validate_hf_hub_args
def load_textual_inversion(
self,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
@@ -316,7 +320,7 @@ class TextualInversionLoaderMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
+28 -119
View File
@@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import OrderedDict, defaultdict
from collections import defaultdict
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..models.embeddings import ImageProjection, MLPProjection, Resampler
from ..models.embeddings import ImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
USE_PEFT_BACKEND,
_get_model_file,
delete_adapter_layers,
@@ -61,7 +62,6 @@ class UNet2DConditionLoadersMixin:
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
@validate_hf_hub_args
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
r"""
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
@@ -95,7 +95,7 @@ class UNet2DConditionLoadersMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
@@ -130,12 +130,12 @@ class UNet2DConditionLoadersMixin:
from ..models.attention_processor import CustomDiffusionAttnProcessor
from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
@@ -184,7 +184,7 @@ class UNet2DConditionLoadersMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -204,7 +204,7 @@ class UNet2DConditionLoadersMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -672,20 +672,6 @@ class UNet2DConditionLoadersMixin:
IPAdapterAttnProcessor2_0,
)
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds = 4
elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
# Set encoder_hid_proj after loading ip_adapter weights,
# because `Resampler` also has `attn_processors`.
self.encoder_hid_proj = None
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 1
@@ -709,10 +695,7 @@ class UNet2DConditionLoadersMixin:
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
).to(dtype=self.dtype, device=self.device)
value_dict = {}
@@ -725,100 +708,26 @@ class UNet2DConditionLoadersMixin:
self.set_attn_processor(attn_procs)
# create image projection layers.
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
image_projection.to(dtype=self.dtype, device=self.device)
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
)
image_projection.to(dtype=self.dtype, device=self.device)
# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict
# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)
elif "proj.3.weight" in state_dict["image_proj"]:
clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0]
image_projection = MLPProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
)
image_projection.to(dtype=self.dtype, device=self.device)
# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
"ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
"ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
"ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
"norm.weight": state_dict["image_proj"]["proj.3.weight"],
"norm.bias": state_dict["image_proj"]["proj.3.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict
else:
# IP-Adapter Plus
embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
hidden_dims = state_dict["image_proj"]["latents"].shape[2]
heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
image_projection = Resampler(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
heads=heads,
num_queries=num_image_text_embeds,
)
image_proj_state_dict = state_dict["image_proj"]
new_sd = OrderedDict()
for k, v in image_proj_state_dict.items():
if "0.to" in k:
k = k.replace("0.to", "2.to")
elif "1.0.weight" in k:
k = k.replace("1.0.weight", "3.0.weight")
elif "1.0.bias" in k:
k = k.replace("1.0.bias", "3.0.bias")
elif "1.1.weight" in k:
k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
elif "1.3.weight" in k:
k = k.replace("1.3.weight", "3.1.net.2.weight")
if "norm1" in k:
new_sd[k.replace("0.norm1", "0")] = v
elif "norm2" in k:
new_sd[k.replace("0.norm2", "1")] = v
elif "to_kv" in k:
v_chunk = v.chunk(2, dim=0)
new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in k:
new_sd[k.replace("to_out", "to_out.0")] = v
else:
new_sd[k] = v
image_projection.load_state_dict(new_sd)
del image_proj_state_dict
image_projection.load_state_dict(image_proj_state_dict)
self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.config.encoder_hid_dim_type = "ip_image_proj"
+2 -6
View File
@@ -32,9 +32,7 @@ if is_torch_available():
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnetxs"] = ["ControlNetXSModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["prior_transformer"] = ["PriorTransformer"]
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
@@ -44,7 +42,7 @@ if is_torch_available():
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]
@@ -64,9 +62,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
from .controlnetxs import ControlNetXSModel
from .dual_transformer_2d import DualTransformer2DModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
@@ -76,7 +72,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandinsky3 import Kandinsky3UNet
from .unet_kandi3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel
+6 -9
View File
@@ -55,12 +55,11 @@ class GELU(nn.Module):
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.proj = nn.Linear(dim_in, dim_out)
self.approximate = approximate
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
@@ -82,14 +81,13 @@ class GEGLU(nn.Module):
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
self.proj = linear_cls(dim_in, dim_out * 2)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
@@ -111,12 +109,11 @@ class ApproximateGELU(nn.Module):
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.proj = nn.Linear(dim_in, dim_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
+5 -7
View File
@@ -501,7 +501,6 @@ class FeedForward(nn.Module):
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(
@@ -512,7 +511,6 @@ class FeedForward(nn.Module):
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
bias: bool = True,
):
super().__init__()
inner_dim = int(dim * mult)
@@ -520,13 +518,13 @@ class FeedForward(nn.Module):
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
act_fn = GELU(dim, inner_dim)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
act_fn = GELU(dim, inner_dim, approximate="tanh")
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim, bias=bias)
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
act_fn = ApproximateGELU(dim, inner_dim)
self.net = nn.ModuleList([])
# project in
@@ -534,7 +532,7 @@ class FeedForward(nn.Module):
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
self.net.append(linear_cls(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
+46 -133
View File
@@ -16,7 +16,7 @@ from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from torch import einsum, nn
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
@@ -109,19 +109,15 @@ class Attention(nn.Module):
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.inner_dim = dim_head * heads
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -130,7 +126,7 @@ class Attention(nn.Module):
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
@@ -182,7 +178,6 @@ class Attention(nn.Module):
else:
linear_cls = LoRACompatibleLinear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
@@ -198,7 +193,7 @@ class Attention(nn.Module):
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
@@ -695,32 +690,6 @@ class Attention(nn.Module):
return encoder_hidden_states
@torch.no_grad()
def fuse_projections(self, fuse=True):
is_cross_attention = self.cross_attention_dim != self.query_dim
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if not is_cross_attention:
# fetch weight matrices.
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
# create a new single projection layer and copy over the weights.
self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
self.fused_projections = fuse
class AttnProcessor:
r"""
@@ -1213,6 +1182,9 @@ class AttnProcessor2_0:
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1279,103 +1251,6 @@ class AttnProcessor2_0:
return hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is currently 🧪 experimental in nature and can change in future.
</Tip>
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
if encoder_hidden_states is None:
qkv = attn.to_qkv(hidden_states, *args)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query = attn.to_q(hidden_states, *args)
kv = attn.to_kv(encoder_hidden_states, *args)
split_size = kv.shape[-1] // 2
key, value = torch.split(kv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class CustomDiffusionXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -2344,6 +2219,44 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
return hidden_states
# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
# this way torch.compile and co. will work as well
class Kandi3AttnProcessor:
r"""
Default kandinsky3 proccesor for performing attention-related computations.
"""
@staticmethod
def _reshape(hid_states, h):
b, n, f = hid_states.shape
d = f // h
return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
def __call__(
self,
attn,
x,
context,
context_mask=None,
):
query = self._reshape(attn.to_q(x), h=attn.num_heads)
key = self._reshape(attn.to_k(context), h=attn.num_heads)
value = self._reshape(attn.to_v(context), h=attn.num_heads)
attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
if context_mask is not None:
max_neg_value = -torch.finfo(attention_matrix.dtype).max
context_mask = context_mask.unsqueeze(1).unsqueeze(1)
attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
out = attn.to_out[0](out)
return out
LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -2369,12 +2282,12 @@ CROSS_ATTENTION_PROCESSORS = (
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
Kandi3AttnProcessor,
)
AttentionProcessor = Union[
AttnProcessor,
AttnProcessor2_0,
FusedAttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
AttnAddedKVProcessor,
-39
View File
@@ -22,7 +22,6 @@ from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -449,41 +448,3 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return (dec,)
return DecoderOutput(sample=dec)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
-977
View File
@@ -1,977 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.normalization import GroupNorm
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import (
AttentionProcessor,
)
from .autoencoder_kl import AutoencoderKL
from .lora import LoRACompatibleConv
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
Downsample2D,
ResnetBlock2D,
Transformer2DModel,
UpBlock2D,
Upsample2D,
)
from .unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ControlNetXSOutput(BaseOutput):
"""
The output of [`ControlNetXSModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model
output, but is already the final output.
"""
sample: torch.FloatTensor = None
# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding
class ControlNetConditioningEmbedding(nn.Module):
"""
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 latent images for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..."
"""
def __init__(
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
):
super().__init__()
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = zero_module(
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
)
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
class ControlNetXSModel(ModelMixin, ConfigMixin):
r"""
A ControlNet-XS model
This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic
methods implemented for all models (such as downloading or saving).
Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation
of [`UNet2DConditionModel`] for them.
Parameters:
conditioning_channels (`int`, defaults to 3):
Number of channels of conditioning input (e.g. an image)
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
time_embedding_input_dim (`int`, defaults to 320):
Dimension of input into time embedding. Needs to be same as in the base model.
time_embedding_dim (`int`, defaults to 1280):
Dimension of output from time embedding. Needs to be same as in the base model.
learn_embedding (`bool`, defaults to `False`):
Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of
the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`.
time_embedding_mix (`float`, defaults to 1.0):
Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the
control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used.
base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`):
Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it.
"""
@classmethod
def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True):
"""
Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS).
Parameters:
base_model (`UNet2DConditionModel`):
Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL.
is_sdxl (`bool`, defaults to `True`):
Whether passed `base_model` is a StableDiffusion-XL model.
"""
def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int):
"""
Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why).
The original ControlNet-XS model, however, define the number of attention heads.
That's why compute the dimensions needed to get the correct number of attention heads.
"""
block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels]
dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels]
return dim_attn_heads
if is_sdxl:
return ControlNetXSModel.from_unet(
base_model,
time_embedding_mix=0.95,
learn_embedding=True,
size_ratio=0.1,
conditioning_embedding_out_channels=(16, 32, 96, 256),
num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64),
)
else:
return ControlNetXSModel.from_unet(
base_model,
time_embedding_mix=1.0,
learn_embedding=True,
size_ratio=0.0125,
conditioning_embedding_out_channels=(16, 32, 96, 256),
num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8),
)
@classmethod
def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str):
"""To create correctly sized connections between base and control model, we need to know
the input and output channels of each subblock.
Parameters:
unet (`UNet2DConditionModel`):
Unet of which the subblock channels sizes are to be gathered.
base_or_control (`str`):
Needs to be either "base" or "control". If "base", decoder is also considered.
"""
if base_or_control not in ["base", "control"]:
raise ValueError("`base_or_control` needs to be either `base` or `control`")
channel_sizes = {"down": [], "mid": [], "up": []}
# input convolution
channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels))
# encoder blocks
for module in unet.down_blocks:
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
for r in module.resnets:
channel_sizes["down"].append((r.in_channels, r.out_channels))
if module.downsamplers:
channel_sizes["down"].append(
(module.downsamplers[0].channels, module.downsamplers[0].out_channels)
)
else:
raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.")
# middle block
channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels))
# decoder blocks
if base_or_control == "base":
for module in unet.up_blocks:
if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)):
for r in module.resnets:
channel_sizes["up"].append((r.in_channels, r.out_channels))
else:
raise ValueError(
f"Encountered unknown module of type {type(module)} while creating ControlNet-XS."
)
return channel_sizes
@register_to_config
def __init__(
self,
conditioning_channels: int = 3,
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
controlnet_conditioning_channel_order: str = "rgb",
time_embedding_input_dim: int = 320,
time_embedding_dim: int = 1280,
time_embedding_mix: float = 1.0,
learn_embedding: bool = False,
base_model_channel_sizes: Dict[str, List[Tuple[int]]] = {
"down": [
(4, 320),
(320, 320),
(320, 320),
(320, 320),
(320, 640),
(640, 640),
(640, 640),
(640, 1280),
(1280, 1280),
],
"mid": [(1280, 1280)],
"up": [
(2560, 1280),
(2560, 1280),
(1920, 1280),
(1920, 640),
(1280, 640),
(960, 640),
(960, 320),
(640, 320),
(640, 320),
],
},
sample_size: Optional[int] = None,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
norm_num_groups: Optional[int] = 32,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
upcast_attention: bool = False,
):
super().__init__()
# 1 - Create control unet
self.control_model = UNet2DConditionModel(
sample_size=sample_size,
down_block_types=down_block_types,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
attention_head_dim=num_attention_heads,
use_linear_projection=True,
upcast_attention=upcast_attention,
time_embedding_dim=time_embedding_dim,
)
# 2 - Do model surgery on control model
# 2.1 - Allow to use the same time information as the base model
adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim)
# 2.2 - Allow for information infusion from base model
# We concat the output of each base encoder subblocks to the input of the next control encoder subblock
# (We ignore the 1st element, as it represents the `conv_in`.)
extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]]
it_extra_input_channels = iter(extra_input_channels)
for b, block in enumerate(self.control_model.down_blocks):
for r in range(len(block.resnets)):
increase_block_input_in_encoder_resnet(
self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels)
)
if block.downsamplers:
increase_block_input_in_encoder_downsampler(
self.control_model, block_no=b, by=next(it_extra_input_channels)
)
increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1])
# 2.3 - Make group norms work with modified channel sizes
adjust_group_norms(self.control_model)
# 3 - Gather Channel Sizes
self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control")
self.ch_inout_base = base_model_channel_sizes
# 4 - Build connections between base and control model
self.down_zero_convs_out = nn.ModuleList([])
self.down_zero_convs_in = nn.ModuleList([])
self.middle_block_out = nn.ModuleList([])
self.middle_block_in = nn.ModuleList([])
self.up_zero_convs_out = nn.ModuleList([])
self.up_zero_convs_in = nn.ModuleList([])
for ch_io_base in self.ch_inout_base["down"]:
self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1]))
for i in range(len(self.ch_inout_ctrl["down"])):
self.down_zero_convs_out.append(
self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1])
)
self.middle_block_out = self._make_zero_conv(
self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1]
)
self.up_zero_convs_out.append(
self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1])
)
for i in range(1, len(self.ch_inout_ctrl["down"])):
self.up_zero_convs_out.append(
self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1])
)
# 5 - Create conditioning hint embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
# In the mininal implementation setting, we only need the control model up to the mid block
del self.control_model.up_blocks
del self.control_model.conv_norm_out
del self.control_model.conv_out
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
conditioning_channels: int = 3,
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
controlnet_conditioning_channel_order: str = "rgb",
learn_embedding: bool = False,
time_embedding_mix: float = 1.0,
block_out_channels: Optional[Tuple[int]] = None,
size_ratio: Optional[float] = None,
num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
norm_num_groups: Optional[int] = None,
):
r"""
Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it.
conditioning_channels (`int`, defaults to 3):
Number of channels of conditioning input (e.g. an image)
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
learn_embedding (`bool`, defaults to `False`):
Wether to use time embedding of the control model. If yes, the time embedding is a linear interpolation
of the time embeddings of the control and base model with interpolation parameter
`time_embedding_mix**3`.
time_embedding_mix (`float`, defaults to 1.0):
Linear interpolation parameter used if `learn_embedding` is `True`.
block_out_channels (`Tuple[int]`, *optional*):
Down blocks output channels in control model. Either this or `size_ratio` must be given.
size_ratio (float, *optional*):
When given, block_out_channels is set to a relative fraction of the base model's block_out_channels.
Either this or `block_out_channels` must be given.
num_attention_heads (`Union[int, Tuple[int]]`, *optional*):
The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
norm_num_groups (int, *optional*, defaults to `None`):
The number of groups to use for the normalization of the control unet. If `None`,
`int(unet.config.norm_num_groups * size_ratio)` is taken.
"""
# Check input
fixed_size = block_out_channels is not None
relative_size = size_ratio is not None
if not (fixed_size ^ relative_size):
raise ValueError(
"Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)."
)
# Create model
if block_out_channels is None:
block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels]
# Check that attention heads and group norms match channel sizes
# - attention heads
def attn_heads_match_channel_sizes(attn_heads, channel_sizes):
if isinstance(attn_heads, (tuple, list)):
return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes))
else:
return all(c % attn_heads == 0 for c in channel_sizes)
num_attention_heads = num_attention_heads or unet.config.attention_head_dim
if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels):
raise ValueError(
f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually."
)
# - group norms
def group_norms_match_channel_sizes(num_groups, channel_sizes):
return all(c % num_groups == 0 for c in channel_sizes)
if norm_num_groups is None:
if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels):
norm_num_groups = unet.config.norm_num_groups
else:
norm_num_groups = min(block_out_channels)
if group_norms_match_channel_sizes(norm_num_groups, block_out_channels):
print(
f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information."
)
else:
raise ValueError(
f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels."
)
def get_time_emb_input_dim(unet: UNet2DConditionModel):
return unet.time_embedding.linear_1.in_features
def get_time_emb_dim(unet: UNet2DConditionModel):
return unet.time_embedding.linear_2.out_features
# Clone params from base unet if
# (i) it's required to build SD or SDXL, and
# (ii) it's not used for the time embedding (as time embedding of control model is never used), and
# (iii) it's not set further below anyway
to_keep = [
"cross_attention_dim",
"down_block_types",
"sample_size",
"transformer_layers_per_block",
"up_block_types",
"upcast_attention",
]
kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep}
kwargs.update(block_out_channels=block_out_channels)
kwargs.update(num_attention_heads=num_attention_heads)
kwargs.update(norm_num_groups=norm_num_groups)
# Add controlnetxs-specific params
kwargs.update(
conditioning_channels=conditioning_channels,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
time_embedding_input_dim=get_time_emb_input_dim(unet),
time_embedding_dim=get_time_emb_dim(unet),
time_embedding_mix=time_embedding_mix,
learn_embedding=learn_embedding,
base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"),
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
)
return cls(**kwargs)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
return self.control_model.attn_processors
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
self.control_model.set_attn_processor(processor, _remove_lora)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.control_model.set_default_attn_processor()
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
self.control_model.set_attention_slice(slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (UNet2DConditionModel)):
if value:
module.enable_gradient_checkpointing()
else:
module.disable_gradient_checkpointing()
def forward(
self,
base_model: UNet2DConditionModel,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
return_dict: bool = True,
) -> Union[ControlNetXSOutput, Tuple]:
"""
The [`ControlNetModel`] forward method.
Args:
base_model (`UNet2DConditionModel`):
The base unet model we want to control.
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
How much the control model affects the base model outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# scale control strength
n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out)
scale_list = torch.full((n_connections,), conditioning_scale)
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = base_model.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
if self.config.learn_embedding:
ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond)
base_temb = base_model.time_embedding(t_emb, timestep_cond)
interpolation_param = self.config.time_embedding_mix**0.3
temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
else:
temb = base_model.time_embedding(t_emb)
# added time & text embeddings
aug_emb = None
if base_model.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if base_model.config.class_embed_type == "timestep":
class_labels = base_model.time_proj(class_labels)
class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype)
temb = temb + class_emb
if base_model.config.addition_embed_type is not None:
if base_model.config.addition_embed_type == "text":
aug_emb = base_model.add_embedding(encoder_hidden_states)
elif base_model.config.addition_embed_type == "text_image":
raise NotImplementedError()
elif base_model.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = base_model.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(temb.dtype)
aug_emb = base_model.add_embedding(add_embeds)
elif base_model.config.addition_embed_type == "image":
raise NotImplementedError()
elif base_model.config.addition_embed_type == "image_hint":
raise NotImplementedError()
temb = temb + aug_emb if aug_emb is not None else temb
# text embeddings
cemb = encoder_hidden_states
# Preparation
guided_hint = self.controlnet_cond_embedding(controlnet_cond)
h_ctrl = h_base = sample
hs_base, hs_ctrl = [], []
it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map(
iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out)
)
scales = iter(scale_list)
base_down_subblocks = to_sub_blocks(base_model.down_blocks)
ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks)
base_mid_subblocks = to_sub_blocks([base_model.mid_block])
ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block])
base_up_subblocks = to_sub_blocks(base_model.up_blocks)
# Cross Control
# 0 - conv in
h_base = base_model.conv_in(h_base)
h_ctrl = self.control_model.conv_in(h_ctrl)
if guided_hint is not None:
h_ctrl += guided_hint
h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base
hs_base.append(h_base)
hs_ctrl.append(h_ctrl)
# 1 - down
for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks):
h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl
h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock
h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock
h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base
hs_base.append(h_base)
hs_ctrl.append(h_ctrl)
# 2 - mid
h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl
for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks):
h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock
h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock
h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base
# 3 - up
for i, m_base in enumerate(base_up_subblocks):
h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder
h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder
h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs)
h_base = base_model.conv_norm_out(h_base)
h_base = base_model.conv_act(h_base)
h_base = base_model.conv_out(h_base)
if not return_dict:
return h_base
return ControlNetXSOutput(sample=h_base)
def _make_zero_conv(self, in_channels, out_channels=None):
# keep running track of channels sizes
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
@torch.no_grad()
def _check_if_vae_compatible(self, vae: AutoencoderKL):
condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1)
vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
compatible = condition_downscale_factor == vae_downscale_factor
return compatible, condition_downscale_factor, vae_downscale_factor
class SubBlock(nn.ModuleList):
"""A SubBlock is the largest piece of either base or control model, that is executed independently of the other model respectively.
Before each subblock, information is concatted from base to control. And after each subblock, information is added from control to base.
"""
def __init__(self, ms, *args, **kwargs):
if not is_iterable(ms):
ms = [ms]
super().__init__(ms, *args, **kwargs)
def forward(
self,
x: torch.Tensor,
temb: torch.Tensor,
cemb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
"""Iterate through children and pass correct information to each."""
for m in self:
if isinstance(m, ResnetBlock2D):
x = m(x, temb)
elif isinstance(m, Transformer2DModel):
x = m(x, cemb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs).sample
elif isinstance(m, Downsample2D):
x = m(x)
elif isinstance(m, Upsample2D):
x = m(x)
else:
raise ValueError(
f"Type of m is {type(m)} but should be `ResnetBlock2D`, `Transformer2DModel`, `Downsample2D` or `Upsample2D`"
)
return x
def adjust_time_dims(unet: UNet2DConditionModel, in_dim: int, out_dim: int):
unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim)
def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by):
"""Increase channels sizes to allow for additional concatted information from base model"""
r = unet.down_blocks[block_no].resnets[resnet_idx]
old_norm1, old_conv1 = r.norm1, r.conv1
# norm
norm_args = "num_groups num_channels eps affine".split(" ")
for a in norm_args:
assert hasattr(old_norm1, a)
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
norm_kwargs["num_channels"] += by # surgery done here
# conv1
conv1_args = (
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
)
for a in conv1_args:
assert hasattr(old_conv1, a)
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
conv1_kwargs["in_channels"] += by # surgery done here
# conv_shortcut
# as we changed the input size of the block, the input and output sizes are likely different,
# therefore we need a conv_shortcut (simply adding won't work)
conv_shortcut_args_kwargs = {
"in_channels": conv1_kwargs["in_channels"],
"out_channels": conv1_kwargs["out_channels"],
# default arguments from resnet.__init__
"kernel_size": 1,
"stride": 1,
"padding": 0,
"bias": True,
}
# swap old with new modules
unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = LoRACompatibleConv(**conv1_kwargs)
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here
def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
"""Increase channels sizes to allow for additional concatted information from base model"""
old_down = unet.down_blocks[block_no].downsamplers[0].conv
# conv1
args = "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(
" "
)
for a in args:
assert hasattr(old_down, a)
kwargs = {a: getattr(old_down, a) for a in args}
kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor.
kwargs["in_channels"] += by # surgery done here
# swap old with new modules
unet.down_blocks[block_no].downsamplers[0].conv = LoRACompatibleConv(**kwargs)
unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here
def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
"""Increase channels sizes to allow for additional concatted information from base model"""
m = unet.mid_block.resnets[0]
old_norm1, old_conv1 = m.norm1, m.conv1
# norm
norm_args = "num_groups num_channels eps affine".split(" ")
for a in norm_args:
assert hasattr(old_norm1, a)
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
norm_kwargs["num_channels"] += by # surgery done here
# conv1
conv1_args = (
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
)
for a in conv1_args:
assert hasattr(old_conv1, a)
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
conv1_kwargs["in_channels"] += by # surgery done here
# conv_shortcut
# as we changed the input size of the block, the input and output sizes are likely different,
# therefore we need a conv_shortcut (simply adding won't work)
conv_shortcut_args_kwargs = {
"in_channels": conv1_kwargs["in_channels"],
"out_channels": conv1_kwargs["out_channels"],
# default arguments from resnet.__init__
"kernel_size": 1,
"stride": 1,
"padding": 0,
"bias": True,
}
# swap old with new modules
unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs)
unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
unet.mid_block.resnets[0].in_channels += by # surgery done here
def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32):
def find_denominator(number, start):
if start >= number:
return number
while start != 0:
residual = number % start
if residual == 0:
return start
start -= 1
for block in [*unet.down_blocks, unet.mid_block]:
# resnets
for r in block.resnets:
if r.norm1.num_groups < max_num_group:
r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group)
if r.norm2.num_groups < max_num_group:
r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group)
# transformers
if hasattr(block, "attentions"):
for a in block.attentions:
if a.norm.num_groups < max_num_group:
a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group)
def is_iterable(o):
if isinstance(o, str):
return False
try:
iter(o)
return True
except TypeError:
return False
def to_sub_blocks(blocks):
if not is_iterable(blocks):
blocks = [blocks]
sub_blocks = []
for b in blocks:
if hasattr(b, "resnets"):
if hasattr(b, "attentions") and b.attentions is not None:
for r, a in zip(b.resnets, b.attentions):
sub_blocks.append([r, a])
num_resnets = len(b.resnets)
num_attns = len(b.attentions)
if num_resnets > num_attns:
# we can have more resnets than attentions, so add each resnet as separate subblock
for i in range(num_attns, num_resnets):
sub_blocks.append([b.resnets[i]])
else:
for r in b.resnets:
sub_blocks.append([r])
# upsamplers are part of the same subblock
if hasattr(b, "upsamplers") and b.upsamplers is not None:
for u in b.upsamplers:
sub_blocks[-1].extend([u])
# downsamplers are own subblock
if hasattr(b, "downsamplers") and b.downsamplers is not None:
for d in b.downsamplers:
sub_blocks.append([d])
return list(map(SubBlock, sub_blocks))
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
-101
View File
@@ -20,7 +20,6 @@ from torch import nn
from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .attention_processor import Attention
from .lora import LoRACompatibleLinear
@@ -461,18 +460,6 @@ class ImageProjection(nn.Module):
return image_embeds
class MLPProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
super().__init__()
from .attention import FeedForward
self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds: torch.FloatTensor):
return self.norm(self.ff(image_embeds))
class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
@@ -803,91 +790,3 @@ class CaptionProjection(nn.Module):
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class Resampler(nn.Module):
"""Resampler of IP-Adapter Plus.
Args:
----
embed_dims (int): The feature dimension. Defaults to 768.
output_dims (int): The number of output channels, that is the same
number of the channels in the
`unet.config.cross_attention_dim`. Defaults to 1024.
hidden_dims (int): The number of hidden channels. Defaults to 1280.
depth (int): The number of blocks. Defaults to 8.
dim_head (int): The number of head channels. Defaults to 64.
heads (int): Parallel attention heads. Defaults to 16.
num_queries (int): The number of queries. Defaults to 8.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
"""
def __init__(
self,
embed_dims: int = 768,
output_dims: int = 1024,
hidden_dims: int = 1280,
depth: int = 4,
dim_head: int = 64,
heads: int = 16,
num_queries: int = 8,
ffn_ratio: float = 4,
) -> None:
super().__init__()
from .attention import FeedForward # Lazy import to avoid circular import
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
self.proj_in = nn.Linear(embed_dims, hidden_dims)
self.proj_out = nn.Linear(hidden_dims, output_dims)
self.norm_out = nn.LayerNorm(output_dims)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
nn.LayerNorm(hidden_dims),
nn.LayerNorm(hidden_dims),
Attention(
query_dim=hidden_dims,
dim_head=dim_head,
heads=heads,
out_bias=False,
),
nn.Sequential(
nn.LayerNorm(hidden_dims),
FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
),
]
)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
----
x (torch.Tensor): Input Tensor.
Returns:
-------
torch.Tensor: Output Tensor.
"""
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
for ln0, ln1, attn, ff in self.layers:
residual = latents
encoder_hidden_states = ln0(x)
latents = ln1(latents)
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
latents = attn(latents, encoder_hidden_states) + residual
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
+7 -12
View File
@@ -24,17 +24,13 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from .. import __version__, is_torch_available
from ..utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
@@ -201,7 +197,6 @@ class FlaxModelMixin(PushToHubMixin):
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
@@ -293,13 +288,13 @@ class FlaxModelMixin(PushToHubMixin):
```
"""
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
from_pt = kwargs.pop("from_pt", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
@@ -319,7 +314,7 @@ class FlaxModelMixin(PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
**kwargs,
@@ -364,7 +359,7 @@ class FlaxModelMixin(PushToHubMixin):
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
@@ -374,7 +369,7 @@ class FlaxModelMixin(PushToHubMixin):
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `token` or log in with `huggingface-cli "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
+10 -10
View File
@@ -25,13 +25,14 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import safetensors
import torch
from huggingface_hub import create_repo
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn
from .. import __version__
from ..utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
MIN_PEFT_VERSION,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
@@ -534,7 +535,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
)
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a pretrained PyTorch model from a pretrained model configuration.
@@ -571,7 +571,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -640,15 +640,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
from_flax = kwargs.pop("from_flax", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
@@ -718,7 +718,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
@@ -740,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -763,7 +763,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -782,7 +782,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
-37
View File
@@ -25,7 +25,6 @@ from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -795,42 +794,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
sample: torch.FloatTensor,
@@ -1,28 +1,16 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Dict, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import Attention, AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .attention_processor import AttentionProcessor, Kandi3AttnProcessor
from .embeddings import TimestepEmbedding
from .modeling_utils import ModelMixin
@@ -34,6 +22,36 @@ class Kandinsky3UNetOutput(BaseOutput):
sample: torch.FloatTensor = None
# TODO(Yiyi): This class needs to be removed
def set_default_item(condition, item_1, item_2=None):
if condition:
return item_1
else:
return item_2
# TODO(Yiyi): This class needs to be removed
def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}):
if condition:
return layer_1(*args_1, **kwargs_1)
else:
return layer_2(*args_2, **kwargs_2)
# TODO(Yiyi): This class should be removed and be replaced by Timesteps
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, type_tensor=None):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
emb = x[:, None] * emb[None, :]
return torch.cat((emb.sin(), emb.cos()), dim=-1)
class Kandinsky3EncoderProj(nn.Module):
def __init__(self, encoder_hid_dim, cross_attention_dim):
super().__init__()
@@ -69,7 +87,9 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
out_channels = in_channels
init_channels = block_out_channels[0] // 2
self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
# TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same
# self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
self.time_proj = SinusoidalPosEmb(init_channels)
self.time_embedding = TimestepEmbedding(
init_channels,
@@ -86,7 +106,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
hidden_dims = [init_channels] + list(block_out_channels)
in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention]
text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention]
num_blocks = len(block_out_channels) * [layers_per_block]
layer_params = [num_blocks, text_dims, add_self_attention]
rev_layer_params = map(reversed, layer_params)
@@ -98,7 +118,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
zip(in_out_dims, *layer_params)
):
down_sample = level != (self.num_levels - 1)
cat_dims.append(out_dim if level != (self.num_levels - 1) else 0)
cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
self.down_blocks.append(
Kandinsky3DownSampleBlock(
in_dim,
@@ -203,16 +223,18 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(AttnProcessor())
self.set_attn_processor(Kandi3AttnProcessor())
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# TODO(Yiyi): Clean up the following variables - these names should not be used
# but instead only the ones that we pass to forward
x = sample
context_mask = encoder_attention_mask
context = encoder_hidden_states
if not torch.is_tensor(timestep):
dtype = torch.float32 if isinstance(timestep, float) else torch.int32
@@ -222,33 +244,33 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = timestep.expand(sample.shape[0])
time_embed_input = self.time_proj(timestep).to(sample.dtype)
time_embed_input = self.time_proj(timestep).to(x.dtype)
time_embed = self.time_embedding(time_embed_input)
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
context = self.encoder_hid_proj(context)
if encoder_hidden_states is not None:
time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask)
if context is not None:
time_embed = self.add_time_condition(time_embed, context, context_mask)
hidden_states = []
sample = self.conv_in(sample)
x = self.conv_in(x)
for level, down_sample in enumerate(self.down_blocks):
sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
x = down_sample(x, time_embed, context, context_mask)
if level != self.num_levels - 1:
hidden_states.append(sample)
hidden_states.append(x)
for level, up_sample in enumerate(self.up_blocks):
if level != 0:
sample = torch.cat([sample, hidden_states.pop()], dim=1)
sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
x = torch.cat([x, hidden_states.pop()], dim=1)
x = up_sample(x, time_embed, context, context_mask)
sample = self.conv_norm_out(sample)
sample = self.conv_act_out(sample)
sample = self.conv_out(sample)
x = self.conv_norm_out(x)
x = self.conv_act_out(x)
x = self.conv_out(x)
if not return_dict:
return (sample,)
return Kandinsky3UNetOutput(sample=sample)
return (x,)
return Kandinsky3UNetOutput(sample=x)
class Kandinsky3UpSampleBlock(nn.Module):
@@ -268,7 +290,7 @@ class Kandinsky3UpSampleBlock(nn.Module):
self_attention=True,
):
super().__init__()
up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1)
up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1)
hidden_channels = (
[(in_channels + cat_dim, in_channels)]
+ [(in_channels, in_channels)] * (num_blocks - 2)
@@ -281,27 +303,27 @@ class Kandinsky3UpSampleBlock(nn.Module):
self.self_attention = self_attention
self.context_dim = context_dim
if self_attention:
attentions.append(
Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
attentions.append(
set_default_layer(
self_attention,
Kandinsky3AttentionBlock,
(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
else:
attentions.append(nn.Identity())
)
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
)
if context_dim is not None:
attentions.append(
Kandinsky3AttentionBlock(
in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
)
attentions.append(
set_default_layer(
context_dim is not None,
Kandinsky3AttentionBlock,
(in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
else:
attentions.append(nn.Identity())
)
resnets_out.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
@@ -345,29 +367,29 @@ class Kandinsky3DownSampleBlock(nn.Module):
self.self_attention = self_attention
self.context_dim = context_dim
if self_attention:
attentions.append(
Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
attentions.append(
set_default_layer(
self_attention,
Kandinsky3AttentionBlock,
(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
else:
attentions.append(nn.Identity())
)
up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]]
up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]]
hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
if context_dim is not None:
attentions.append(
Kandinsky3AttentionBlock(
out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
)
attentions.append(
set_default_layer(
context_dim is not None,
Kandinsky3AttentionBlock,
(out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
else:
attentions.append(nn.Identity())
)
resnets_out.append(
Kandinsky3ResNetBlock(
out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
@@ -409,23 +431,68 @@ class Kandinsky3ConditionalGroupNorm(nn.Module):
return x
# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. I'm pretty
# sure we can delete it and instead just pass an attention_mask
class Attention(nn.Module):
def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
super().__init__()
assert out_channels % head_dim == 0
self.num_heads = out_channels // head_dim
self.scale = head_dim**-0.5
# to_q
self.to_q = nn.Linear(in_channels, out_channels, bias=False)
# to_k
self.to_k = nn.Linear(context_dim, out_channels, bias=False)
# to_v
self.to_v = nn.Linear(context_dim, out_channels, bias=False)
processor = Kandi3AttnProcessor()
self.set_processor(processor)
# to_out
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))
def set_processor(self, processor: "AttnProcessor"): # noqa: F821
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def forward(self, x, context, context_mask=None, image_mask=None):
return self.processor(
self,
x,
context=context,
context_mask=context_mask,
)
class Kandinsky3Block(nn.Module):
def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
super().__init__()
self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
self.activation = nn.SiLU()
if up_resolution is not None and up_resolution:
self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
else:
self.up_sample = nn.Identity()
self.up_sample = set_default_layer(
up_resolution is not None and up_resolution,
nn.ConvTranspose2d,
(in_channels, in_channels),
{"kernel_size": 2, "stride": 2},
)
padding = int(kernel_size > 1)
self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
if up_resolution is not None and not up_resolution:
self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
else:
self.down_sample = nn.Identity()
self.down_sample = set_default_layer(
up_resolution is not None and not up_resolution,
nn.Conv2d,
(out_channels, out_channels),
{"kernel_size": 2, "stride": 2},
)
def forward(self, x, time_embed):
x = self.group_norm(x, time_embed)
@@ -454,18 +521,14 @@ class Kandinsky3ResNetBlock(nn.Module):
)
]
)
self.shortcut_up_sample = (
nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
if True in up_resolutions
else nn.Identity()
self.shortcut_up_sample = set_default_layer(
True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2}
)
self.shortcut_projection = (
nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
self.shortcut_projection = set_default_layer(
in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1}
)
self.shortcut_down_sample = (
nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
if False in up_resolutions
else nn.Identity()
self.shortcut_down_sample = set_default_layer(
False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2}
)
def forward(self, x, time_embed):
@@ -483,16 +546,9 @@ class Kandinsky3ResNetBlock(nn.Module):
class Kandinsky3AttentionPooling(nn.Module):
def __init__(self, num_channels, context_dim, head_dim=64):
super().__init__()
self.attention = Attention(
context_dim,
context_dim,
dim_head=head_dim,
out_dim=num_channels,
out_bias=False,
)
self.attention = Attention(context_dim, num_channels, context_dim, head_dim)
def forward(self, x, context, context_mask=None):
context_mask = context_mask.to(dtype=context.dtype)
context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
return x + context.squeeze(1)
@@ -501,13 +557,7 @@ class Kandinsky3AttentionBlock(nn.Module):
def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
super().__init__()
self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
self.attention = Attention(
num_channels,
context_dim or num_channels,
dim_head=head_dim,
out_dim=num_channels,
out_bias=False,
)
self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)
hidden_channels = expansion_ratio * num_channels
self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
@@ -522,10 +572,14 @@ class Kandinsky3AttentionBlock(nn.Module):
out = self.in_norm(x, time_embed)
out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
context = context if context is not None else out
if context_mask is not None:
context_mask = context_mask.to(dtype=context.dtype)
out = self.attention(out, context, context_mask)
if image_mask is not None:
mask_height, mask_width = image_mask.shape[-2:]
kernel_size = (mask_height // height, mask_width // width)
image_mask = F.max_pool2d(image_mask, kernel_size, kernel_size)
image_mask = image_mask.reshape(image_mask.shape[0], -1)
out = self.attention(out, context, context_mask, image_mask)
out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
x = x + out
-11
View File
@@ -19,7 +19,6 @@ from ..utils import (
_dummy_objects = {}
_import_structure = {
"controlnet": [],
"controlnet_xs": [],
"latent_diffusion": [],
"stable_diffusion": [],
"stable_diffusion_xl": [],
@@ -94,12 +93,6 @@ else:
"StableDiffusionXLControlNetPipeline",
]
)
_import_structure["controlnet_xs"].extend(
[
"StableDiffusionControlNetXSPipeline",
"StableDiffusionXLControlNetXSPipeline",
]
)
_import_structure["deepfloyd_if"] = [
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -354,10 +347,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
)
from .controlnet_xs import (
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
)
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMR
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -494,29 +494,18 @@ class AltDiffusionPipeline(
return prompt_embeds, negative_prompt_embeds
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds, uncond_image_embeds
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -886,10 +875,7 @@ class AltDiffusionPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
)
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMR
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -505,29 +505,18 @@ class AltDiffusionImg2ImgPipeline(
return prompt_embeds, negative_prompt_embeds
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds, uncond_image_embeds
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -930,10 +919,7 @@ class AltDiffusionImg2ImgPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
)
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unet_motion_model import MotionAdapter
from ...schedulers import (
@@ -320,29 +320,18 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds, uncond_image_embeds
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
@@ -662,10 +651,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
)
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt)
if do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
+13 -17
View File
@@ -16,9 +16,8 @@
import inspect
from collections import OrderedDict
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import DIFFUSERS_CACHE
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
@@ -196,7 +195,6 @@ class AutoPipelineForText2Image(ConfigMixin):
)
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
r"""
Instantiates a text-to-image Pytorch diffusion pipeline from pretrained pipeline weight.
@@ -248,7 +246,7 @@ class AutoPipelineForText2Image(ConfigMixin):
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -312,11 +310,11 @@ class AutoPipelineForText2Image(ConfigMixin):
>>> image = pipeline(prompt).images[0]
```
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
@@ -325,7 +323,7 @@ class AutoPipelineForText2Image(ConfigMixin):
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"token": token,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
}
@@ -468,7 +466,6 @@ class AutoPipelineForImage2Image(ConfigMixin):
)
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
r"""
Instantiates a image-to-image Pytorch diffusion pipeline from pretrained pipeline weight.
@@ -521,7 +518,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -585,11 +582,11 @@ class AutoPipelineForImage2Image(ConfigMixin):
>>> image = pipeline(prompt, image).images[0]
```
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
@@ -598,7 +595,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"token": token,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
}
@@ -745,7 +742,6 @@ class AutoPipelineForInpainting(ConfigMixin):
)
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
r"""
Instantiates a inpainting Pytorch diffusion pipeline from pretrained pipeline weight.
@@ -797,7 +793,7 @@ class AutoPipelineForInpainting(ConfigMixin):
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -861,11 +857,11 @@ class AutoPipelineForInpainting(ConfigMixin):
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
```
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
@@ -874,7 +870,7 @@ class AutoPipelineForInpainting(ConfigMixin):
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"token": token,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
}
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -479,29 +479,18 @@ class StableDiffusionControlNetPipeline(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds, uncond_image_embeds
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -1078,10 +1067,7 @@ class StableDiffusionControlNetPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
)
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -597,29 +597,18 @@ class StableDiffusionControlNetInpaintPipeline(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds, uncond_image_embeds
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -1295,10 +1284,7 @@ class StableDiffusionControlNetInpaintPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
)
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
@@ -37,7 +37,7 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -489,29 +489,18 @@ class StableDiffusionXLControlNetPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds, uncond_image_embeds
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -1180,10 +1169,7 @@ class StableDiffusionXLControlNetPipeline(
# 3.2 Encode ip_adapter_image
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
)
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
@@ -1,68 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"]
_import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -1,944 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> # !pip install opencv-python transformers accelerate
>>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSModel
>>> from diffusers.utils import load_image
>>> import numpy as np
>>> import torch
>>> import cv2
>>> from PIL import Image
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
>>> negative_prompt = "low quality, bad quality, sketches"
>>> # download an image
>>> image = load_image(
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
... )
>>> # initialize the models and pipeline
>>> controlnet_conditioning_scale = 0.5
>>> controlnet = ControlNetXSModel.from_pretrained(
... "UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
... )
>>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
... "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> pipe.enable_model_cpu_offload()
>>> # get canny image
>>> image = np.array(image)
>>> image = cv2.Canny(image, 100, 200)
>>> image = image[:, :, None]
>>> image = np.concatenate([image, image, image], axis=2)
>>> canny_image = Image.fromarray(image)
>>> # generate image
>>> image = pipe(
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
... ).images[0]
```
"""
class StableDiffusionControlNetXSPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
A `UNet2DConditionModel` to denoise the encoded image latents.
controlnet ([`ControlNetXSModel`]):
Provides additional conditioning to the `unet` during the denoising process.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae>controlnet"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
controlnet: ControlNetXSModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible(
vae
)
if not vae_compatible:
raise ValueError(
f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
prompt_embeds_tuple = self.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
**kwargs,
)
# concatenate for backwards comp
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
image,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Check `image`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
)
if (
isinstance(self.controlnet, ControlNetXSModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
):
self.check_image(image, prompt, prompt_embeds)
else:
assert False
# Check `controlnet_conditioning_scale`
if (
isinstance(self.controlnet, ControlNetXSModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
else:
assert False
start, end = control_guidance_start, control_guidance_end
if start >= end:
raise ValueError(
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
)
if start < 0.0:
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
image_is_np = isinstance(image, np.ndarray)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
if (
not image_is_pil
and not image_is_tensor
and not image_is_np
and not image_is_pil_list
and not image_is_tensor_list
and not image_is_np_list
):
raise TypeError(
f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
)
if image_is_pil:
image_batch_size = 1
else:
image_batch_size = len(image)
if prompt is not None and isinstance(prompt, str):
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
):
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
control_guidance_start: float = 0.0,
control_guidance_end: float = 1.0,
clip_skip: Optional[int] = None,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
`init`, images must be passed as a list such that each element of the list can be correctly batched for
input to a single ControlNet.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
the corresponding scale as a list.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
image,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare image
if isinstance(controlnet, ControlNetXSModel):
image = self.prepare_image(
image=image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
height, width = image.shape[-2:]
else:
assert False
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
is_unet_compiled = is_compiled_module(self.unet)
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
dont_control = (
i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end
)
if dont_control:
noise_pred = self.unet(
sample=latent_model_input,
timestep=t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=True,
).sample
else:
noise_pred = self.controlnet(
base_model=self.unet,
sample=latent_model_input,
timestep=t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
conditioning_scale=controlnet_conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=True,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
File diff suppressed because it is too large Load Diff
@@ -21,8 +21,8 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"]
_import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"]
_import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"]
_import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -33,8 +33,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_kandinsky3 import Kandinsky3Pipeline
from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline
from .kandinsky3_pipeline import Kandinsky3Pipeline
from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline
else:
import sys

Some files were not shown because too many files have changed in this diff Show More