Compare commits
114 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a99b0eb517 | |||
| 9112028ed8 | |||
| dce06680d2 | |||
| dd63168319 | |||
| 1040dfd9cc | |||
| 49a4b377c1 | |||
| dff35a86e4 | |||
| 8842bcadb9 | |||
| 181280baba | |||
| 53f498d2a4 | |||
| 990860911f | |||
| 23eed39702 | |||
| fefed44543 | |||
| 814f56d2fe | |||
| 96d6e16550 | |||
| c11de13588 | |||
| 357855f8fc | |||
| f825221b5d | |||
| 119d734f6e | |||
| cb4b3f0b78 | |||
| 3d574b3bbe | |||
| 09903774d9 | |||
| d6a70d8ba8 | |||
| e3103e171f | |||
| b053053ac9 | |||
| 08702fc1cb | |||
| 7ce89e979c | |||
| 05faf3263b | |||
| a080f0d3a2 | |||
| 79df50388d | |||
| 7d631825b0 | |||
| 33d2b5b087 | |||
| f486d34b04 | |||
| e44b205e0b | |||
| 60cb44323d | |||
| 1dd0ac9401 | |||
| c6b04589b6 | |||
| 5dc3471380 | |||
| 9df566e6da | |||
| be0b425762 | |||
| da843b3d53 | |||
| 17cece072a | |||
| a551ddf928 | |||
| 1d57892980 | |||
| 3e8b63216e | |||
| dd4459ad79 | |||
| 6313645b6b | |||
| 2d1f2182cc | |||
| 3be7c96e28 | |||
| 3c79dd9dbe | |||
| 9d767916da | |||
| ae7cd5ad4c | |||
| 4497b3ec98 | |||
| fc63ebdd3a | |||
| 8b6fae4ce5 | |||
| aa1797e109 | |||
| 5bacc2f5af | |||
| 6ae7e8112a | |||
| e0f349c2b0 | |||
| 774f5c4581 | |||
| a483a8eddf | |||
| 9a9daee724 | |||
| 585f941366 | |||
| 86a26761ac | |||
| 6ef2b8a92f | |||
| 3848606c7e | |||
| 2a97067b84 | |||
| ae060fc4f1 | |||
| 9d945b2b90 | |||
| d184291c7d | |||
| 0a0bb526aa | |||
| 2fada8dc1b | |||
| f2d51a28f7 | |||
| 811fd06292 | |||
| f3d1333e02 | |||
| acd926f4f2 | |||
| 691d8d3e15 | |||
| e7c0af5e71 | |||
| 107e02160a | |||
| 6dbef45e6e | |||
| 7715e6c31c | |||
| 05b3d36a25 | |||
| fb4aec0ce3 | |||
| 63de23e3db | |||
| 2993257f2a | |||
| aad18faa3e | |||
| d700140076 | |||
| 2e4dc3e25d | |||
| 3e2961f0b4 | |||
| 79c380bc80 | |||
| e30b661437 | |||
| b4077af212 | |||
| 9f2bff502e | |||
| 0cb92717f9 | |||
| 86714b72d0 | |||
| 61f6c5472a | |||
| 17546020fc | |||
| 8a366b835c | |||
| 61d223c884 | |||
| bf725e044e | |||
| 1622265e13 | |||
| 0b63ad5ad5 | |||
| 6a376ceea2 | |||
| 9f283b01d2 | |||
| 203724e9d9 | |||
| e7044a4221 | |||
| 034b39b8cb | |||
| 2db73f4a50 | |||
| 84d7faebe4 | |||
| 4c483deb90 | |||
| 1ac07d8a8d | |||
| 1fff527702 | |||
| 645a62bf3b | |||
| 6414d4e4f9 |
@@ -1,4 +1,4 @@
|
||||
contact_links:
|
||||
- name: Forum
|
||||
url: https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63
|
||||
- name: Questions / Discussions
|
||||
url: https://github.com/huggingface/diffusers/discussions
|
||||
about: General usage questions and community discussions
|
||||
|
||||
@@ -38,7 +38,7 @@ members/contributors who may be interested in your PR.
|
||||
|
||||
Core library:
|
||||
|
||||
- Schedulers: @williamberman and @patrickvonplaten
|
||||
- Schedulers: @yiyixuxu and @patrickvonplaten
|
||||
- Pipelines: @patrickvonplaten and @sayakpaul
|
||||
- Training examples: @sayakpaul and @patrickvonplaten
|
||||
- Docs: @stevhliu and @yiyixuxu
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
name: Delete doc comment
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Delete doc comment trigger"]
|
||||
types:
|
||||
- completed
|
||||
|
||||
|
||||
jobs:
|
||||
delete:
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
|
||||
secrets:
|
||||
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
|
||||
@@ -1,12 +0,0 @@
|
||||
name: Delete doc comment trigger
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [ closed ]
|
||||
|
||||
|
||||
jobs:
|
||||
delete:
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
|
||||
with:
|
||||
pr_number: ${{ github.event.number }}
|
||||
@@ -59,7 +59,7 @@ jobs:
|
||||
|
||||
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/lora/test_lora_layers_peft.py
|
||||
|
||||
@@ -189,7 +189,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-s -v -k "not Flax and not Onnx and not PEFTLoRALoading" \
|
||||
--make-reports=tests_peft_cuda \
|
||||
tests/lora/
|
||||
|
||||
|
||||
@@ -40,7 +40,6 @@ RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
omegaconf
|
||||
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -40,7 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
omegaconf \
|
||||
pytorch-lightning
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -40,7 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
omegaconf \
|
||||
xformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -160,6 +160,8 @@
|
||||
title: xFormers
|
||||
- local: optimization/tome
|
||||
title: Token merging
|
||||
- local: optimization/deepcache
|
||||
title: DeepCache
|
||||
title: General optimizations
|
||||
- sections:
|
||||
- local: using-diffusers/stable_diffusion_jax_how_to
|
||||
@@ -210,6 +212,8 @@
|
||||
title: Textual Inversion
|
||||
- local: api/loaders/unet
|
||||
title: UNet
|
||||
- local: api/loaders/peft
|
||||
title: PEFT
|
||||
title: Loaders
|
||||
- sections:
|
||||
- local: api/models/overview
|
||||
@@ -329,6 +333,8 @@
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/upscale
|
||||
title: Super-resolution
|
||||
- local: api/pipelines/stable_diffusion/k_diffusion
|
||||
title: K-Diffusion
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
|
||||
- local: api/pipelines/stable_diffusion/adapter
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# PEFT
|
||||
|
||||
Diffusers supports loading adapters such as [LoRA](../../using-diffusers/loading_adapters) with the [PEFT](https://huggingface.co/docs/peft/index) library with the [`~loaders.peft.PeftAdapterMixin`] class. This allows modeling classes in Diffusers like [`UNet2DConditionModel`] to load an adapter.
|
||||
|
||||
<Tip>
|
||||
|
||||
Refer to the [Inference with PEFT](../../tutorials/using_peft_for_inference.md) tutorial for an overview of how to use PEFT in Diffusers for inference.
|
||||
|
||||
</Tip>
|
||||
|
||||
## PeftAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.peft.PeftAdapterMixin
|
||||
@@ -12,10 +12,16 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# aMUSEd
|
||||
|
||||
Amused is a lightweight text to image model based off of the [muse](https://arxiv.org/pdf/2301.00704.pdf) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once.
|
||||
aMUSEd was introduced in [aMUSEd: An Open MUSE Reproduction](https://huggingface.co/papers/2401.01808) by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen.
|
||||
|
||||
Amused is a lightweight text to image model based off of the [MUSE](https://arxiv.org/abs/2301.00704) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once.
|
||||
|
||||
Amused is a vqvae token based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with muse, it uses the smaller text encoder CLIP-L/14 instead of t5-xxl. Due to its small parameter count and few forward pass generation process, amused can generate many images quickly. This benefit is seen particularly at larger batch sizes.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present aMUSEd, an open-source, lightweight masked image model (MIM) for text-to-image generation based on MUSE. With 10 percent of MUSE's parameters, aMUSEd is focused on fast image generation. We believe MIM is under-explored compared to latent diffusion, the prevailing approach for text-to-image generation. Compared to latent diffusion, MIM requires fewer inference steps and is more interpretable. Additionally, MIM can be fine-tuned to learn additional styles with only a single image. We hope to encourage further exploration of MIM by demonstrating its effectiveness on large-scale text-to-image generation and releasing reproducible training code. We also release checkpoints for two models which directly produce images at 256x256 and 512x512 resolutions.*
|
||||
|
||||
| Model | Params |
|
||||
|-------|--------|
|
||||
| [amused-256](https://huggingface.co/amused/amused-256) | 603M |
|
||||
|
||||
@@ -235,6 +235,62 @@ export_to_gif(frames, "animation.gif")
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Using FreeInit
|
||||
|
||||
[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.
|
||||
|
||||
FreeInit is an effective method that improves temporal consistency and overall quality of videos generated using video-diffusion-models without any addition training. It can be applied to AnimateDiff, ModelScope, VideoCrafter and various other video generation models seamlessly at inference time, and works by iteratively refining the latent-initialization noise. More details can be found it the paper.
|
||||
|
||||
The following example demonstrates the usage of FreeInit.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
|
||||
from diffusers.utils import export_to_gif
|
||||
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
|
||||
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
|
||||
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to("cuda")
|
||||
pipe.scheduler = DDIMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
beta_schedule="linear",
|
||||
clip_sample=False,
|
||||
timestep_spacing="linspace",
|
||||
steps_offset=1
|
||||
)
|
||||
|
||||
# enable memory savings
|
||||
pipe.enable_vae_slicing()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
# enable FreeInit
|
||||
# Refer to the enable_free_init documentation for a full list of configurable parameters
|
||||
pipe.enable_free_init(method="butterworth", use_fast_sampling=True)
|
||||
|
||||
# run inference
|
||||
output = pipe(
|
||||
prompt="a panda playing a guitar, on a boat, in the ocean, high quality",
|
||||
negative_prompt="bad quality, worse quality",
|
||||
num_frames=16,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=20,
|
||||
generator=torch.Generator("cpu").manual_seed(666),
|
||||
)
|
||||
|
||||
# disable FreeInit
|
||||
pipe.disable_free_init()
|
||||
|
||||
frames = output.frames[0]
|
||||
export_to_gif(frames, "animation.gif")
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
@@ -248,6 +304,8 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
- __call__
|
||||
- enable_freeu
|
||||
- disable_freeu
|
||||
- enable_free_init
|
||||
- disable_free_init
|
||||
- enable_vae_slicing
|
||||
- disable_vae_slicing
|
||||
- enable_vae_tiling
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# K-Diffusion
|
||||
|
||||
[k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable DIffusion with samplers from k-diffusion.
|
||||
|
||||
Note that most the samplers from k-diffusion are implemented in Diffusers and we recommend using existing schedulers. You can find a mapping between k-diffusion samplers and schedulers in Diffusers [here](https://huggingface.co/docs/diffusers/api/schedulers/overview)
|
||||
|
||||
|
||||
## StableDiffusionKDiffusionPipeline
|
||||
|
||||
[[autodoc]] StableDiffusionKDiffusionPipeline
|
||||
|
||||
|
||||
## StableDiffusionXLKDiffusionPipeline
|
||||
|
||||
[[autodoc]] StableDiffusionXLKDiffusionPipeline
|
||||
@@ -37,8 +37,10 @@ source .env/bin/activate
|
||||
|
||||
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models:
|
||||
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
Note - PyTorch only supports Python 3.8 - 3.11 on Windows.
|
||||
```bash
|
||||
pip install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# DeepCache
|
||||
[DeepCache](https://huggingface.co/papers/2312.00858) accelerates [`StableDiffusionPipeline`] and [`StableDiffusionXLPipeline`] by strategically caching and reusing high-level features while efficiently updating low-level features by taking advantage of the U-Net architecture.
|
||||
|
||||
Start by installing [DeepCache](https://github.com/horseee/DeepCache):
|
||||
```bash
|
||||
pip install DeepCache
|
||||
```
|
||||
|
||||
Then load and enable the [`DeepCacheSDHelper`](https://github.com/horseee/DeepCache#usage):
|
||||
|
||||
```diff
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
+ from DeepCache import DeepCacheSDHelper
|
||||
+ helper = DeepCacheSDHelper(pipe=pipe)
|
||||
+ helper.set_params(
|
||||
+ cache_interval=3,
|
||||
+ cache_branch_id=0,
|
||||
+ )
|
||||
+ helper.enable()
|
||||
|
||||
image = pipe("a photo of an astronaut on a moon").images[0]
|
||||
```
|
||||
|
||||
The `set_params` method accepts two arguments: `cache_interval` and `cache_branch_id`. `cache_interval` means the frequency of feature caching, specified as the number of steps between each cache operation. `cache_branch_id` identifies which branch of the network (ordered from the shallowest to the deepest layer) is responsible for executing the caching processes.
|
||||
Opting for a lower `cache_branch_id` or a larger `cache_interval` can lead to faster inference speed at the expense of reduced image quality (ablation experiments of these two hyperparameters can be found in the [paper](https://arxiv.org/abs/2312.00858)). Once those arguments are set, use the `enable` or `disable` methods to activate or deactivate the `DeepCacheSDHelper`.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://github.com/horseee/Diffusion_DeepCache/raw/master/static/images/example.png">
|
||||
</div>
|
||||
|
||||
You can find more generated samples (original pipeline vs DeepCache) and the corresponding inference latency in the [WandB report](https://wandb.ai/horseee/DeepCache/runs/jwlsqqgt?workspace=user-horseee). The prompts are randomly selected from the [MS-COCO 2017](https://cocodataset.org/#home) dataset.
|
||||
|
||||
## Benchmark
|
||||
|
||||
We tested how much faster DeepCache accelerates [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) with 50 inference steps on an NVIDIA RTX A5000, using different configurations for resolution, batch size, cache interval (I), and cache branch (B).
|
||||
|
||||
| **Resolution** | **Batch size** | **Original** | **DeepCache(I=3, B=0)** | **DeepCache(I=5, B=0)** | **DeepCache(I=5, B=1)** |
|
||||
|----------------|----------------|--------------|-------------------------|-------------------------|-------------------------|
|
||||
| 512| 8| 15.96| 6.88(2.32x)| 5.03(3.18x)| 7.27(2.20x)|
|
||||
| | 4| 8.39| 3.60(2.33x)| 2.62(3.21x)| 3.75(2.24x)|
|
||||
| | 1| 2.61| 1.12(2.33x)| 0.81(3.24x)| 1.11(2.35x)|
|
||||
| 768| 8| 43.58| 18.99(2.29x)| 13.96(3.12x)| 21.27(2.05x)|
|
||||
| | 4| 22.24| 9.67(2.30x)| 7.10(3.13x)| 10.74(2.07x)|
|
||||
| | 1| 6.33| 2.72(2.33x)| 1.97(3.21x)| 2.98(2.12x)|
|
||||
| 1024| 8| 101.95| 45.57(2.24x)| 33.72(3.02x)| 53.00(1.92x)|
|
||||
| | 4| 49.25| 21.86(2.25x)| 16.19(3.04x)| 25.78(1.91x)|
|
||||
| | 1| 13.83| 6.07(2.28x)| 4.43(3.12x)| 7.15(1.93x)|
|
||||
@@ -12,17 +12,11 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Accelerate inference of text-to-image diffusion models
|
||||
|
||||
Diffusion models are known to be slower than their counter parts, GANs, because of the iterative and sequential reverse diffusion process. Recent works try to address limitation with:
|
||||
Diffusion models are slower than their GAN counterparts because of the iterative and sequential reverse diffusion process. There are several techniques that can address this limitation such as progressive timestep distillation ([LCM LoRA](../using-diffusers/inference_with_lcm_lora)), model compression ([SSD-1B](https://huggingface.co/segmind/SSD-1B)), and reusing adjacent features of the denoiser ([DeepCache](../optimization/deepcache)).
|
||||
|
||||
* progressive timestep distillation (such as [LCM LoRA](../using-diffusers/inference_with_lcm_lora.md))
|
||||
* model compression (such as [SSD-1B](https://huggingface.co/segmind/SSD-1B))
|
||||
* reusing adjacent features of the denoiser (such as [DeepCache](https://github.com/horseee/DeepCache))
|
||||
However, you don't necessarily need to use these techniques to speed up inference. With PyTorch 2 alone, you can accelerate the inference latency of text-to-image diffusion pipelines by up to 3x. This tutorial will show you how to progressively apply the optimizations found in PyTorch 2 to reduce inference latency. You'll use the [Stable Diffusion XL (SDXL)](../using-diffusers/sdxl) pipeline in this tutorial, but these techniques are applicable to other text-to-image diffusion pipelines too.
|
||||
|
||||
In this tutorial, we focus on leveraging the power of PyTorch 2 to accelerate the inference latency of text-to-image diffusion pipeline, instead. We will use [Stable Diffusion XL (SDXL)](../using-diffusers/sdxl.md) as a case study, but the techniques we will discuss should extend to other text-to-image diffusion pipelines.
|
||||
|
||||
## Setup
|
||||
|
||||
Make sure you're on the latest version of `diffusers`:
|
||||
Make sure you're using the latest version of Diffusers:
|
||||
|
||||
```bash
|
||||
pip install -U diffusers
|
||||
@@ -34,15 +28,23 @@ Then upgrade the other required libraries too:
|
||||
pip install -U transformers accelerate peft
|
||||
```
|
||||
|
||||
To benefit from the fastest kernels, use PyTorch nightly. You can find the installation instructions [here](https://pytorch.org/).
|
||||
Install [PyTorch nightly](https://pytorch.org/) to benefit from the latest and fastest kernels:
|
||||
|
||||
To report the numbers shown below, we used an 80GB 400W A100 with its clock rate set to the maximum.
|
||||
```bash
|
||||
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
|
||||
```
|
||||
|
||||
_This tutorial doesn't present the benchmarking code and focuses on how to perform the optimizations, instead. For the full benchmarking code, refer to: [https://github.com/huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast)._
|
||||
<Tip>
|
||||
|
||||
The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum. <br>
|
||||
|
||||
If you're interested in the full benchmarking code, take a look at [huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast).
|
||||
|
||||
</Tip>
|
||||
|
||||
## Baseline
|
||||
|
||||
Let's start with a baseline. Disable the use of a reduced precision and [`scaled_dot_product_attention`](../optimization/torch2.0.md):
|
||||
Let's start with a baseline. Disable reduced precision and the [`scaled_dot_product_attention` (SDPA)](../optimization/torch2.0#scaled-dot-product-attention) function which is automatically used by Diffusers:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
@@ -52,7 +54,7 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0"
|
||||
).to("cuda")
|
||||
|
||||
# Run the attention ops without efficiency.
|
||||
# Run the attention ops without SDPA.
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.vae.set_default_attn_processor()
|
||||
|
||||
@@ -60,27 +62,29 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
```
|
||||
|
||||
This takes 7.36 seconds:
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_0.png" width=500>
|
||||
This default setup takes 7.36 seconds.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_0.png" width=500>
|
||||
</div>
|
||||
|
||||
## Running inference in bfloat16
|
||||
## bfloat16
|
||||
|
||||
Enable the first optimization: use a reduced precision to run the inference.
|
||||
Enable the first optimization, reduced precision or more specifically bfloat16. There are several benefits of using reduced precision:
|
||||
|
||||
* Using a reduced numerical precision (such as float16 or bfloat16) for inference doesn’t affect the generation quality but significantly improves latency.
|
||||
* The benefits of using bfloat16 compared to float16 are hardware dependent, but modern GPUs tend to favor bfloat16.
|
||||
* bfloat16 is much more resilient when used with quantization compared to float16, but more recent versions of the quantization library ([torchao](https://github.com/pytorch-labs/ao)) we used don't have numerical issues with float16.
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
# Run the attention ops without efficiency.
|
||||
# Run the attention ops without SDPA.
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.vae.set_default_attn_processor()
|
||||
|
||||
@@ -88,49 +92,45 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
```
|
||||
|
||||
bfloat16 reduces the latency from 7.36 seconds to 4.63 seconds:
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_1.png" width=500>
|
||||
bfloat16 reduces the latency from 7.36 seconds to 4.63 seconds.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_1.png" width=500>
|
||||
</div>
|
||||
|
||||
**Why bfloat16?**
|
||||
<Tip>
|
||||
|
||||
* Using a reduced numerical precision (such as float16, bfloat16) to run inference doesn’t affect the generation quality but significantly improves latency.
|
||||
* The benefits of using the bfloat16 numerical precision as compared to float16 are hardware-dependent. Modern generations of GPUs tend to favor bfloat16.
|
||||
* Furthermore, in our experiments, we bfloat16 to be much more resilient when used with quantization in comparison to float16.
|
||||
In our later experiments with float16, recent versions of torchao do not incur numerical problems from float16.
|
||||
|
||||
We have a [dedicated guide](../optimization/fp16.md) for running inference in a reduced precision.
|
||||
</Tip>
|
||||
|
||||
## Running attention efficiently
|
||||
Take a look at the [Speed up inference](../optimization/fp16) guide to learn more about running inference with reduced precision.
|
||||
|
||||
Attention blocks are intensive to run. But with PyTorch's [`scaled_dot_product_attention`](../optimization/torch2.0.md), we can run them efficiently.
|
||||
## SDPA
|
||||
|
||||
Attention blocks are intensive to run. But with PyTorch's [`scaled_dot_product_attention`](../optimization/torch2.0#scaled-dot-product-attention) function, it is a lot more efficient. This function is used by default in Diffusers so you don't need to make any changes to the code.
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
```
|
||||
|
||||
`scaled_dot_product_attention` improves the latency from 4.63 seconds to 3.31 seconds.
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_2.png" width=500>
|
||||
Scaled dot product attention improves the latency from 4.63 seconds to 3.31 seconds.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_2.png" width=500>
|
||||
</div>
|
||||
|
||||
## Use faster kernels with torch.compile
|
||||
## torch.compile
|
||||
|
||||
Compile the UNet and the VAE to benefit from the faster kernels. First, configure a few compiler flags:
|
||||
PyTorch 2 includes `torch.compile` which uses fast and optimized kernels. In Diffusers, the UNet and VAE are usually compiled because these are the most compute-intensive modules. First, configure a few compiler flags (refer to the [full list](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py) for more options):
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
@@ -142,16 +142,14 @@ torch._inductor.config.epilogue_fusion = False
|
||||
torch._inductor.config.coordinate_descent_check_all_directions = True
|
||||
```
|
||||
|
||||
For the full list of compiler flags, refer to [this file](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py).
|
||||
|
||||
It is also important to change the memory layout of the UNet and the VAE to “channels_last” when compiling them. This ensures maximum speed:
|
||||
It is also important to change the UNet and VAE's memory layout to "channels_last" when compiling them to ensure maximum speed.
|
||||
|
||||
```python
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.vae.to(memory_format=torch.channels_last)
|
||||
```
|
||||
|
||||
Then, compile and perform inference:
|
||||
Now compile and perform inference:
|
||||
|
||||
```python
|
||||
# Compile the UNet and VAE.
|
||||
@@ -160,57 +158,65 @@ pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
# First call to `pipe` will be slow, subsequent ones will be faster.
|
||||
# First call to `pipe` is slow, subsequent ones are faster.
|
||||
image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
```
|
||||
|
||||
`torch.compile` offers different backends and modes. As we’re aiming for maximum inference speed, we opt for the inductor backend using the “max-autotune”. “max-autotune” uses CUDA graphs and optimizes the compilation graph specifically for latency. Specifying fullgraph to be True ensures that there are no graph breaks in the underlying model, ensuring the fullest potential of `torch.compile`.
|
||||
`torch.compile` offers different backends and modes. For maximum inference speed, use "max-autotune" for the inductor backend. “max-autotune” uses CUDA graphs and optimizes the compilation graph specifically for latency. CUDA graphs greatly reduces the overhead of launching GPU operations by using a mechanism to launch multiple GPU operations through a single CPU operation.
|
||||
|
||||
Using SDPA attention and compiling both the UNet and VAE reduces the latency from 3.31 seconds to 2.54 seconds.
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_3.png" width=500>
|
||||
Using SDPA attention and compiling both the UNet and VAE cuts the latency from 3.31 seconds to 2.54 seconds.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_3.png" width=500>
|
||||
</div>
|
||||
|
||||
## Combine the projection matrices of attention
|
||||
### Prevent graph breaks
|
||||
|
||||
Both the UNet and the VAE used in SDXL make use of Transformer-like blocks. A Transformer block consists of attention blocks and feed-forward blocks.
|
||||
Specifying `fullgraph=True` ensures there are no graph breaks in the underlying model to take full advantage of `torch.compile` without any performance degradation. For the UNet and VAE, this means changing how you access the return variables.
|
||||
|
||||
In an attention block, the input is projected into three sub-spaces using three different projection matrices – Q, K, and V. In the naive implementation, these projections are performed separately on the input. But we can horizontally combine the projection matrices into a single matrix and perform the projection in one shot. This increases the size of the matmuls of the input projections and improves the impact of quantization (to be discussed next).
|
||||
```diff
|
||||
- latents = unet(
|
||||
- latents, timestep=timestep, encoder_hidden_states=prompt_embeds
|
||||
-).sample
|
||||
|
||||
Enabling this kind of computation in Diffusers just takes a single line of code:
|
||||
+ latents = unet(
|
||||
+ latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
|
||||
+)[0]
|
||||
```
|
||||
|
||||
### Remove GPU sync after compilation
|
||||
|
||||
During the iterative reverse diffusion process, the `step()` function is [called](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) on the scheduler each time after the denoiser predicts the less noisy latent embeddings. Inside `step()`, the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476) which when placed on the GPU, causes a communication sync between the CPU and GPU. This introduces latency and it becomes more evident when the denoiser has already been compiled.
|
||||
|
||||
But if the `sigmas` array always [stays on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240), the CPU and GPU sync doesn’t occur and you don't get any latency. In general, any CPU and GPU communication sync should be none or be kept to a bare minimum because it can impact inference latency.
|
||||
|
||||
## Combine the attention block's projection matrices
|
||||
|
||||
The UNet and VAE in SDXL use Transformer-like blocks which consists of attention blocks and feed-forward blocks.
|
||||
|
||||
In an attention block, the input is projected into three sub-spaces using three different projection matrices – Q, K, and V. These projections are performed separately on the input. But we can horizontally combine the projection matrices into a single matrix and perform the projection in one step. This increases the size of the matrix multiplications of the input projections and improves the impact of quantization.
|
||||
|
||||
You can combine the projection matrices with just a single line of code:
|
||||
|
||||
```python
|
||||
pipe.fuse_qkv_projections()
|
||||
```
|
||||
|
||||
It provides a minor boost from 2.54 seconds to 2.52 seconds.
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_4.png" width=500>
|
||||
This provides a minor improvement from 2.54 seconds to 2.52 seconds.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_4.png" width=500>
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Support for `fuse_qkv_projections()` is limited and experimental. As such, it's not available for many non-SD pipelines such as [Kandinsky](../using-diffusers/kandinsky.md). You can refer to [this PR](https://github.com/huggingface/diffusers/pull/6179) to get an idea about how to support this kind of computation.
|
||||
Support for [`~StableDiffusionXLPipeline.fuse_qkv_projections`] is limited and experimental. It's not available for many non-Stable Diffusion pipelines such as [Kandinsky](../using-diffusers/kandinsky). You can refer to this [PR](https://github.com/huggingface/diffusers/pull/6179) to get an idea about how to enable this for the other pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Dynamic quantization
|
||||
|
||||
Aapply [dynamic int8 quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) to both the UNet and the VAE. This is because quantization adds additional conversion overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization). If the matmuls are too small, these techniques may degrade performance.
|
||||
|
||||
<Tip>
|
||||
|
||||
Through experimentation, we found that certain linear layers in the UNet and the VAE don’t benefit from dynamic int8 quantization. You can check out the full code for filtering those layers [here](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16) (referred to as `dynamic_quant_filter_fn` below).
|
||||
|
||||
</Tip>
|
||||
|
||||
You will leverage the ultra-lightweight pure PyTorch library [torchao](https://github.com/pytorch-labs/ao) to use its user-friendly APIs for quantization.
|
||||
You can also use the ultra-lightweight PyTorch quantization library, [torchao](https://github.com/pytorch-labs/ao) (commit SHA `54bcd5a10d0abbe7b0c045052029257099f83fd9`), to apply [dynamic int8 quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) to the UNet and VAE. Quantization adds additional conversion overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization). If the matmuls are too small, these techniques may degrade performance.
|
||||
|
||||
First, configure all the compiler tags:
|
||||
|
||||
@@ -227,7 +233,7 @@ torch._inductor.config.force_fuse_int_mm_with_mul = True
|
||||
torch._inductor.config.use_mixed_mm = True
|
||||
```
|
||||
|
||||
Define the filtering functions:
|
||||
Certain linear layers in the UNet and VAE don’t benefit from dynamic int8 quantization. You can filter out those layers with the [`dynamic_quant_filter_fn`](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16) shown below.
|
||||
|
||||
```python
|
||||
def dynamic_quant_filter_fn(mod, *args):
|
||||
@@ -265,12 +271,12 @@ def conv_filter_fn(mod, *args):
|
||||
)
|
||||
```
|
||||
|
||||
Then apply all the optimizations discussed so far:
|
||||
Finally, apply all the optimizations discussed so far:
|
||||
|
||||
```python
|
||||
# SDPA + bfloat16.
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
# Combine attention projection matrices.
|
||||
@@ -281,7 +287,7 @@ pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.vae.to(memory_format=torch.channels_last)
|
||||
```
|
||||
|
||||
Since this quantization support is limited to linear layers only, we also turn suitable pointwise convolution layers into linear layers to maximize the benefit.
|
||||
Since dynamic quantization is only limited to the linear layers, convert the appropriate pointwise convolution layers into linear layers to maximize its benefit.
|
||||
|
||||
```python
|
||||
from torchao import swap_conv2d_1x1_to_linear
|
||||
@@ -311,8 +317,6 @@ image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
|
||||
Applying dynamic quantization improves the latency from 2.52 seconds to 2.43 seconds.
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_5.png" width=500>
|
||||
|
||||
</div>
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_5.png" width=500>
|
||||
</div>
|
||||
|
||||
@@ -64,20 +64,16 @@ With callbacks, you can implement features such as dynamic CFG without having to
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## Using Callbacks to interrupt the Diffusion Process
|
||||
|
||||
The following Pipelines support interrupting the diffusion process via callback
|
||||
|
||||
- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md)
|
||||
- [StableDiffusionImg2ImgPipeline](..api/pipelines/stable_diffusion/img2img.md)
|
||||
- [StableDiffusionInpaintPipeline](..api/pipelines/stable_diffusion/inpaint.md)
|
||||
- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
|
||||
- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
|
||||
- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
|
||||
## Interrupt the diffusion process
|
||||
|
||||
Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.
|
||||
|
||||
<Tip>
|
||||
|
||||
The interruption callback is supported for text-to-image, image-to-image, and inpainting for the [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview) and [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl).
|
||||
|
||||
</Tip>
|
||||
|
||||
This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
|
||||
|
||||
In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.
|
||||
|
||||
@@ -429,7 +429,7 @@ image = pipe(
|
||||
make_image_grid([original_image, canny_image, image], rows=1, cols=3)
|
||||
```
|
||||
|
||||
### MultiControlNet
|
||||
## MultiControlNet
|
||||
|
||||
<Tip>
|
||||
|
||||
|
||||
@@ -77,12 +77,42 @@ Throughout this guide, the mask image is provided in all of the code examples fo
|
||||
Upload a base image to inpaint on and use the sketch tool to draw a mask. Once you're done, click **Run** to generate and download the mask image.
|
||||
|
||||
<iframe
|
||||
src="https://stevhliu-inpaint-mask-maker.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="450"
|
||||
src="https://stevhliu-inpaint-mask-maker.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="450"
|
||||
></iframe>
|
||||
|
||||
### Mask blur
|
||||
|
||||
The [`~VaeImageProcessor.blur`] method provides an option for how to blend the original image and inpaint area. The amount of blur is determined by the `blur_factor` parameter. Increasing the `blur_factor` increases the amount of blur applied to the mask edges, softening the transition between the original image and inpaint area. A low or zero `blur_factor` preserves the sharper edges of the mask.
|
||||
|
||||
To use this, create a blurred mask with the image processor.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to('cuda')
|
||||
|
||||
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png")
|
||||
blurred_mask = pipeline.mask_processor.blur(mask, blur_factor=33)
|
||||
blurred_mask
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">mask with no blur</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/mask_blurred.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">mask with blur applied</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Popular models
|
||||
|
||||
[Stable Diffusion Inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2 Inpainting](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
|
||||
@@ -318,7 +348,7 @@ make_image_grid([init_image, image], rows=1, cols=2)
|
||||
|
||||
The trade-off of using a non-inpaint specific checkpoint is the overall image quality may be lower, but it generally tends to preserve the mask area (that is why you can see the mask outline). The inpaint specific checkpoints are intentionally trained to generate higher quality inpainted images, and that includes creating a more natural transition between the masked and unmasked areas. As a result, these checkpoints are more likely to change your unmasked area.
|
||||
|
||||
If preserving the unmasked area is important for your task, you can use the code below to force the unmasked area of an image to remain the same at the expense of some more unnatural transitions between the masked and unmasked areas.
|
||||
If preserving the unmasked area is important for your task, you can use the [`VaeImageProcessor.apply_overlay`] method to force the unmasked area of an image to remain the same at the expense of some more unnatural transitions between the masked and unmasked areas.
|
||||
|
||||
```py
|
||||
import PIL
|
||||
@@ -345,18 +375,7 @@ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
repainted_image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
repainted_image.save("repainted_image.png")
|
||||
|
||||
# Convert mask to grayscale NumPy array
|
||||
mask_image_arr = np.array(mask_image.convert("L"))
|
||||
# Add a channel dimension to the end of the grayscale mask
|
||||
mask_image_arr = mask_image_arr[:, :, None]
|
||||
# Binarize the mask: 1s correspond to the pixels which are repainted
|
||||
mask_image_arr = mask_image_arr.astype(np.float32) / 255.0
|
||||
mask_image_arr[mask_image_arr < 0.5] = 0
|
||||
mask_image_arr[mask_image_arr >= 0.5] = 1
|
||||
|
||||
# Take the masked pixels from the repainted image and the unmasked pixels from the initial image
|
||||
unmasked_unchanged_image_arr = (1 - mask_image_arr) * init_image + mask_image_arr * repainted_image
|
||||
unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.round().astype("uint8"))
|
||||
unmasked_unchanged_image = pipeline.image_processor.apply_overlay(mask_image, init_image, repainted_image)
|
||||
unmasked_unchanged_image.save("force_unmasked_unchanged.png")
|
||||
make_image_grid([init_image, mask_image, repainted_image, unmasked_unchanged_image], rows=2, cols=2)
|
||||
```
|
||||
@@ -486,6 +505,39 @@ make_image_grid([init_image, mask_image, image], rows=1, cols=3)
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
### Padding mask crop
|
||||
|
||||
A method for increasing the inpainting image quality is to use the [`padding_mask_crop`](https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusionInpaintPipeline.__call__.padding_mask_crop) parameter. When enabled, this option crops the masked area with some user-specified padding and it'll also crop the same area from the original image. Both the image and mask are upscaled to a higher resolution for inpainting, and then overlaid on the original image. This is a quick and easy way to improve image quality without using a separate pipeline like [`StableDiffusionUpscalePipeline`].
|
||||
|
||||
Add the `padding_mask_crop` parameter to the pipeline call and set it to the desired padding value.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
|
||||
generator = torch.Generator(device='cuda').manual_seed(0)
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to('cuda')
|
||||
|
||||
base = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png")
|
||||
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png")
|
||||
|
||||
image = pipeline("boat", image=base, mask_image=mask, strength=0.75, generator=generator, padding_mask_crop=32).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/baseline_inpaint.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">default inpaint image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/padding_mask_crop_inpaint.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">inpaint image with `padding_mask_crop` enabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Chained inpainting pipelines
|
||||
|
||||
[`AutoPipelineForInpainting`] can be chained with other 🤗 Diffusers pipelines to edit their outputs. This is often useful for improving the output quality from your other diffusion pipelines, and if you're using multiple pipelines, it can be more memory-efficient to chain them together to keep the outputs in latent space and reuse the same pipeline components.
|
||||
|
||||
@@ -344,7 +344,8 @@ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-a
|
||||
IP-Adapter relies on an image encoder to generate the image features, if your IP-Adapter weights folder contains a "image_encoder" subfolder, the image encoder will be automatically loaded and registered to the pipeline. Otherwise you can so load a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to a Stable Diffusion pipeline when you create it.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, CLIPVisionModelWithProjection
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
import torch
|
||||
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
|
||||
@@ -26,7 +26,7 @@ Before you begin, make sure you have the following libraries installed:
|
||||
|
||||
```py
|
||||
# uncomment to install the necessary libraries in Colab
|
||||
#!pip install -q diffusers transformers accelerate omegaconf invisible-watermark>=0.2.0
|
||||
#!pip install -q diffusers transformers accelerate invisible-watermark>=0.2.0
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
@@ -23,7 +23,7 @@ Before you begin, make sure you have the following libraries installed:
|
||||
|
||||
```py
|
||||
# uncomment to install the necessary libraries in Colab
|
||||
#!pip install -q diffusers transformers accelerate omegaconf
|
||||
#!pip install -q diffusers transformers accelerate
|
||||
```
|
||||
|
||||
## Load model checkpoints
|
||||
|
||||
@@ -14,9 +14,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
[Stable Video Diffusion](https://static1.squarespace.com/static/6213c340453c3f502425776e/t/655ce779b9d47d342a93c890/1700587395994/stable_video_diffusion.pdf) is a powerful image-to-video generation model that can generate high resolution (576x1024) 2-4 second videos conditioned on the input image.
|
||||
[Stable Video Diffusion (SVD)](https://huggingface.co/papers/2311.15127) is a powerful image-to-video generation model that can generate 2-4 second high resolution (576x1024) videos conditioned on an input image.
|
||||
|
||||
This guide will show you how to use SVD to short generate videos from images.
|
||||
This guide will show you how to use SVD to generate short videos from images.
|
||||
|
||||
Before you begin, make sure you have the following libraries installed:
|
||||
|
||||
@@ -24,13 +24,9 @@ Before you begin, make sure you have the following libraries installed:
|
||||
!pip install -q -U diffusers transformers accelerate
|
||||
```
|
||||
|
||||
## Image to Video Generation
|
||||
The are two variants of this model, [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) and [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt). The SVD checkpoint is trained to generate 14 frames and the SVD-XT checkpoint is further finetuned to generate 25 frames.
|
||||
|
||||
The are two variants of SVD. [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid)
|
||||
and [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt). The svd checkpoint is trained to generate 14 frames and the svd-xt checkpoint is further
|
||||
finetuned to generate 25 frames.
|
||||
|
||||
We will use the `svd-xt` checkpoint for this guide.
|
||||
You'll use the SVD-XT checkpoint for this guide.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -53,27 +49,20 @@ frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
|
||||
export_to_video(frames, "generated.mp4", fps=7)
|
||||
```
|
||||
|
||||
<video controls width="1024" height="576">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.webm" type="video/webm" />
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4" type="video/mp4" />
|
||||
</video>
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">"source image of a rocket"</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/output_rocket.gif"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">"generated video from source image"</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
| **Source Image** | **Video** |
|
||||
|:------------:|:-----:|
|
||||
|  |  |
|
||||
## torch.compile
|
||||
|
||||
|
||||
<Tip>
|
||||
Since generating videos is more memory intensive we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory.
|
||||
Setting `decode_chunk_size=1` will decode one frame at a time and will use the least amount of memory but the video might have some flickering.
|
||||
|
||||
Additionally, we also use [model cpu offloading](../../optimization/memory#model-offloading) to reduce the memory usage.
|
||||
</Tip>
|
||||
|
||||
|
||||
### Torch.compile
|
||||
|
||||
You can achieve a 20-25% speed-up at the expense of slightly increased memory by compiling the UNet as follows:
|
||||
You can gain a 20-25% speedup at the expense of slightly increased memory by [compiling](../optimization/torch2.0#torchcompile) the UNet.
|
||||
|
||||
```diff
|
||||
- pipe.enable_model_cpu_offload()
|
||||
@@ -81,37 +70,33 @@ You can achieve a 20-25% speed-up at the expense of slightly increased memory by
|
||||
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
### Low-memory
|
||||
## Reduce memory usage
|
||||
|
||||
Video generation is very memory intensive as we have to essentially generate `num_frames` all at once. The mechanism is very comparable to text-to-image generation with a high batch size. To reduce the memory requirement you have multiple options. The following options trade inference speed against lower memory requirement:
|
||||
- enable model offloading: Each component of the pipeline is offloaded to CPU once it's not needed anymore.
|
||||
- enable feed-forward chunking: The feed-forward layer runs in a loop instead of running with a single huge feed-forward batch size
|
||||
- reduce `decode_chunk_size`: This means that the VAE decodes frames in chunks instead of decoding them all together. **Note**: In addition to leading to a small slowdown, this method also slightly leads to video quality deterioration
|
||||
Video generation is very memory intensive because you're essentially generating `num_frames` all at once, similar to text-to-image generation with a high batch size. To reduce the memory requirement, there are multiple options that trade-off inference speed for lower memory requirement:
|
||||
|
||||
You can enable them as follows:
|
||||
- enable model offloading: each component of the pipeline is offloaded to the CPU once it's not needed anymore.
|
||||
- enable feed-forward chunking: the feed-forward layer runs in a loop instead of running a single feed-forward with a huge batch size.
|
||||
- reduce `decode_chunk_size`: the VAE decodes frames in chunks instead of decoding them all together. Setting `decode_chunk_size=1` decodes one frame at a time and uses the least amount of memory (we recommend adjusting this value based on your GPU memory) but the video might have some flickering.
|
||||
|
||||
```diff
|
||||
-pipe.enable_model_cpu_offload()
|
||||
-frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
|
||||
+pipe.enable_model_cpu_offload()
|
||||
+pipe.unet.enable_forward_chunking()
|
||||
+frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]
|
||||
- pipe.enable_model_cpu_offload()
|
||||
- frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
|
||||
+ pipe.enable_model_cpu_offload()
|
||||
+ pipe.unet.enable_forward_chunking()
|
||||
+ frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]
|
||||
```
|
||||
|
||||
Using all these tricks togethere should lower the memory requirement to less than 8GB VRAM.
|
||||
|
||||
Including all these tricks should lower the memory requirement to less than 8GB VRAM.
|
||||
## Micro-conditioning
|
||||
|
||||
### Micro-conditioning
|
||||
Stable Diffusion Video also accepts micro-conditioning, in addition to the conditioning image, which allows more control over the generated video:
|
||||
|
||||
Along with conditioning image Stable Diffusion Video also allows providing micro-conditioning that allows more control over the generated video.
|
||||
It accepts the following arguments:
|
||||
|
||||
- `fps`: The frames per second of the generated video.
|
||||
- `motion_bucket_id`: The motion bucket id to use for the generated video. This can be used to control the motion of the generated video. Increasing the motion bucket id will increase the motion of the generated video.
|
||||
- `noise_aug_strength`: The amount of noise added to the conditioning image. The higher the values the less the video will resemble the conditioning image. Increasing this value will also increase the motion of the generated video.
|
||||
|
||||
Here is an example of using micro-conditioning to generate a video with more motion.
|
||||
- `fps`: the frames per second of the generated video.
|
||||
- `motion_bucket_id`: the motion bucket id to use for the generated video. This can be used to control the motion of the generated video. Increasing the motion bucket id increases the motion of the generated video.
|
||||
- `noise_aug_strength`: the amount of noise added to the conditioning image. The higher the values the less the video resembles the conditioning image. Increasing this value also increases the motion of the generated video.
|
||||
|
||||
For example, to generate a video with more motion, use the `motion_bucket_id` and `noise_aug_strength` micro-conditioning parameters:
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -134,4 +119,3 @@ export_to_video(frames, "generated.mp4", fps=7)
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
@@ -2,9 +2,13 @@
|
||||
- local: index
|
||||
title: 🧨 Diffusers
|
||||
- local: quicktour
|
||||
title: 簡単な案内
|
||||
title: クイックツアー
|
||||
- local: stable_diffusion
|
||||
title: 効果的で効率的な拡散モデル
|
||||
title: 有効で効率の良い拡散モデル
|
||||
- local: installation
|
||||
title: インストール
|
||||
title: はじめに
|
||||
title: はじめに
|
||||
- sections:
|
||||
- local: tutorials/tutorial_overview
|
||||
title: 概要
|
||||
title: チュートリアル
|
||||
+8
-59
@@ -18,82 +18,31 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Diffusers
|
||||
|
||||
🤗 Diffusers は、画像や音声、さらには分子の3D構造を生成するための、最先端の事前学習済みDiffusion Model(拡散モデル)を提供するライブラリです。シンプルな生成ソリューションをお探しの場合でも、独自の拡散モデルをトレーニングしたい場合でも、🤗 Diffusers はその両方をサポートするモジュール式のツールボックスです。我々のライブラリは、[性能より使いやすさ](conceptual/philosophy#usability-over-performance)、[簡単よりシンプル](conceptual/philosophy#simple-over-easy)、[抽象化よりカスタマイズ性](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction)に重点を置いて設計されています。
|
||||
🤗 Diffusers は、画像や音声、さらには分子の3D構造を生成するための、最先端の事前学習済みDiffusion Model(拡散モデル)を提供するライブラリです。シンプルな生成ソリューションをお探しの場合でも、独自の拡散モデルをトレーニングしたい場合でも、🤗 Diffusers はその両方をサポートするモジュール式のツールボックスです。私たちのライブラリは、[性能より使いやすさ](conceptual/philosophy#usability-over-performance)、[簡単よりシンプル](conceptual/philosophy#simple-over-easy)、[抽象化よりカスタマイズ性](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction)に重点を置いて設計されています。
|
||||
|
||||
このライブラリには3つの主要コンポーネントがあります:
|
||||
|
||||
- 最先端の[拡散パイプライン](api/pipelines/overview)で数行のコードで生成が可能です。
|
||||
- 交換可能な[ノイズスケジューラ](api/schedulers/overview)で生成速度と品質のトレードオフのバランスをとれます。
|
||||
- 事前に訓練された[モデル](api/models)は、ビルディングブロックとして使用することができ、スケジューラと組み合わせることで、独自のエンドツーエンドの拡散システムを作成することができます。
|
||||
- 数行のコードで推論可能な最先端の[拡散パイプライン](api/pipelines/overview)。Diffusersには多くのパイプラインがあります。利用可能なパイプラインを網羅したリストと、それらが解決するタスクについては、パイプラインの[概要](https://huggingface.co/docs/diffusers/api/pipelines/overview)の表をご覧ください。
|
||||
- 生成速度と品質のトレードオフのバランスを取る交換可能な[ノイズスケジューラ](api/schedulers/overview)
|
||||
- ビルディングブロックとして使用することができ、スケジューラと組み合わせることで、エンドツーエンドの拡散モデルを構築可能な事前学習済み[モデル](api/models)
|
||||
|
||||
<div class="mt-10">
|
||||
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="./tutorials/tutorial_overview"
|
||||
><div class="w-full text-center bg-gradient-to-br from-blue-400 to-blue-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">チュートリアル</div>
|
||||
<p class="text-gray-700">出力の生成、独自の拡散システムの構築、拡散モデルのトレーニングを開始するために必要な基本的なスキルを学ぶことができます。初めて🤗Diffusersを使用する場合は、ここから始めることをお勧めします!</p>
|
||||
<p class="text-gray-700">出力の生成、独自の拡散システムの構築、拡散モデルのトレーニングを開始するために必要な基本的なスキルを学ぶことができます。初めて 🤗Diffusersを使用する場合は、ここから始めることをおすすめします!</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="./using-diffusers/loading_overview"
|
||||
><div class="w-full text-center bg-gradient-to-br from-indigo-400 to-indigo-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">ガイド</div>
|
||||
<p class="text-gray-700">パイプライン、モデル、スケジューラのロードに役立つ実践的なガイドです。また、特定のタスクにパイプラインを使用する方法、出力の生成方法を制御する方法、生成速度を最適化する方法、さまざまなトレーニング手法についても学ぶことができます。</p>
|
||||
<p class="text-gray-700">パイプライン、モデル、スケジューラの読み込みに役立つ実践的なガイドです。また、特定のタスクにパイプラインを使用する方法、出力の生成方法を制御する方法、生成速度を最適化する方法、さまざまなトレーニング手法についても学ぶことができます。</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="./conceptual/philosophy"
|
||||
><div class="w-full text-center bg-gradient-to-br from-pink-400 to-pink-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">Conceptual guides</div>
|
||||
<p class="text-gray-700">ライブラリがなぜこのように設計されたのかを理解し、ライブラリを利用する際の倫理的ガイドラインや安全対策について詳しく学べます。</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="./api/models/overview"
|
||||
><div class="w-full text-center bg-gradient-to-br from-purple-400 to-purple-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">Reference</div>
|
||||
><div class="w-full text-center bg-gradient-to-br from-purple-400 to-purple-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">リファレンス</div>
|
||||
<p class="text-gray-700">🤗 Diffusersのクラスとメソッドがどのように機能するかについての技術的な説明です。</p>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Supported pipelines
|
||||
|
||||
| Pipeline | Paper/Repository | Tasks |
|
||||
|---|---|:---:|
|
||||
| [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
|
||||
| [audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation |
|
||||
| [controlnet](./api/pipelines/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
|
||||
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
|
||||
| [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./api/pipelines/ddim) | [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
|
||||
| [if](./if) | [**IF**](./api/pipelines/if) | Image Generation |
|
||||
| [if_img2img](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation |
|
||||
| [if_inpainting](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation |
|
||||
| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
|
||||
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
|
||||
| [paint_by_example](./api/pipelines/paint_by_example) | [Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
|
||||
| [pndm](./api/pipelines/pndm) | [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
|
||||
| [score_sde_ve](./api/pipelines/score_sde_ve) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [score_sde_vp](./api/pipelines/score_sde_vp) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [Semantic Guidance](https://arxiv.org/abs/2301.12247) | Text-Guided Generation |
|
||||
| [stable_diffusion_adapter](./api/pipelines/stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation | -
|
||||
| [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation |
|
||||
| [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation |
|
||||
| [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting |
|
||||
| [stable_diffusion_panorama](./api/pipelines/stable_diffusion/panorama) | [MultiDiffusion](https://multidiffusion.github.io/) | Text-to-Panorama Generation |
|
||||
| [stable_diffusion_pix2pix](./api/pipelines/stable_diffusion/pix2pix) | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://arxiv.org/abs/2211.09800) | Text-Guided Image Editing|
|
||||
| [stable_diffusion_pix2pix_zero](./api/pipelines/stable_diffusion/pix2pix_zero) | [Zero-shot Image-to-Image Translation](https://pix2pixzero.github.io/) | Text-Guided Image Editing |
|
||||
| [stable_diffusion_attend_and_excite](./api/pipelines/stable_diffusion/attend_and_excite) | [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://arxiv.org/abs/2301.13826) | Text-to-Image Generation |
|
||||
| [stable_diffusion_self_attention_guidance](./api/pipelines/stable_diffusion/self_attention_guidance) | [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939) | Text-to-Image Generation Unconditional Image Generation |
|
||||
| [stable_diffusion_image_variation](./stable_diffusion/image_variation) | [Stable Diffusion Image Variations](https://github.com/LambdaLabsML/lambda-diffusers#stable-diffusion-image-variations) | Image-to-Image Generation |
|
||||
| [stable_diffusion_latent_upscale](./stable_diffusion/latent_upscale) | [Stable Diffusion Latent Upscaler](https://twitter.com/StabilityAI/status/1590531958815064065) | Text-Guided Super Resolution Image-to-Image |
|
||||
| [stable_diffusion_model_editing](./api/pipelines/stable_diffusion/model_editing) | [Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://time-diffusion.github.io/) | Text-to-Image Model Editing |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Depth-Conditional Stable Diffusion](https://github.com/Stability-AI/stablediffusion#depth-conditional-stable-diffusion) | Depth-to-Image Generation |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
|
||||
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [Safe Stable Diffusion](https://arxiv.org/abs/2211.05105) | Text-Guided Generation |
|
||||
| [stable_unclip](./stable_unclip) | Stable unCLIP | Text-to-Image Generation |
|
||||
| [stable_unclip](./stable_unclip) | Stable unCLIP | Image-to-Image Text-Guided Generation |
|
||||
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [text_to_video_sd](./api/pipelines/text_to_video) | [Modelscope's Text-to-video-synthesis Model in Open Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) | Text-to-Video Generation |
|
||||
| [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)(implementation by [kakaobrain](https://github.com/kakaobrain/karlo)) | Text-to-Image Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
|
||||
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
|
||||
| [stable_diffusion_ldm3d](./api/pipelines/stable_diffusion/ldm3d_diffusion) | [LDM3D: Latent Diffusion Model for 3D](https://arxiv.org/abs/2305.10853) | Text to Image and Depth Generation |
|
||||
| [stable_diffusion_upscaler_ldm3d](./api/pipelines/stable_diffusion/ldm3d_diffusion) | [LDM3D-VR: Latent Diffusion Model for 3D VR](https://arxiv.org/pdf/2311.03226) | Image and Depth Upscaling |
|
||||
</div>
|
||||
@@ -0,0 +1,23 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Overview
|
||||
|
||||
ようこそ 🧨Diffusersへ!拡散モデル(diffusion models)や生成AIの初心者で、さらに学びたいのであれば、このチュートリアルが最適です。この初心者向けのチュートリアルは、拡散モデルについて丁寧に解説し、ライブラリの基礎(核となるコンポーネントと 🧨Diffusersの使用方法)を理解することを目的としています。
|
||||
|
||||
まず、推論のためのパイプラインを使って、素早く生成する方法を学んでいきます。次に、独自の拡散システムを構築するためのモジュラーツールボックスとしてライブラリをどのように使えば良いかを理解するために、そのパイプラインを分解してみましょう。次のレッスンでは、あなたの欲しいものを生成できるように拡散モデルをトレーニングする方法を学びましょう。
|
||||
|
||||
このチュートリアルがすべて完了したら、ライブラリを自分で調べ、自分のプロジェクトやアプリケーションにどのように使えるかを知るために必要なスキルを身につけることができます。
|
||||
|
||||
そして、 [Discord](https://discord.com/invite/JfAtkvEtRb) や [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) でDiffusersコミュニティに参加してユーザーや開発者と繋がって協力していきましょう。
|
||||
|
||||
さあ、「拡散」をはじめていきましょう!🧨
|
||||
@@ -20,6 +20,7 @@ import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
@@ -37,9 +38,11 @@ from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig, set_peft_model_state_dict
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from safetensors.torch import save_file
|
||||
from safetensors.torch import load_file, save_file
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
@@ -54,52 +57,26 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.models.lora import LoRALinearLayer
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr, unet_lora_state_dict
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_all_state_dict_to_peft,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_kohya,
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
@@ -125,10 +102,17 @@ def save_model_card(
|
||||
img_str += f"""
|
||||
- text: '{instance_prompt}'
|
||||
"""
|
||||
|
||||
embeddings_filename = f"{repo_folder}_emb"
|
||||
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
|
||||
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
|
||||
if instance_prompt_webui != embeddings_filename:
|
||||
instance_prompt_sentence = f"For example, `{instance_prompt_webui}`"
|
||||
else:
|
||||
instance_prompt_sentence = ""
|
||||
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
|
||||
diffusers_imports_pivotal = ""
|
||||
diffusers_example_pivotal = ""
|
||||
webui_example_pivotal = ""
|
||||
if train_text_encoder_ti:
|
||||
trigger_str = (
|
||||
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
|
||||
@@ -137,11 +121,16 @@ def save_model_card(
|
||||
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
"""
|
||||
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
|
||||
- Place it on it on your `embeddings` folder
|
||||
- Use it by adding `{embeddings_filename}` to your prompt. {instance_prompt_sentence}
|
||||
(you need both the LoRA and the embeddings as they were trained together for this LoRA)
|
||||
"""
|
||||
if token_abstraction_dict:
|
||||
for key, value in token_abstraction_dict.items():
|
||||
tokens = "".join(value)
|
||||
@@ -161,8 +150,6 @@ tags:
|
||||
base_model: {base_model}
|
||||
instance_prompt: {instance_prompt}
|
||||
license: openrail++
|
||||
widget:
|
||||
- text: '{validation_prompt if validation_prompt else instance_prompt}'
|
||||
---
|
||||
"""
|
||||
|
||||
@@ -175,9 +162,14 @@ widget:
|
||||
|
||||
### These are {repo_id} LoRA adaption weights for {base_model}.
|
||||
|
||||
## Trigger words
|
||||
## Download model
|
||||
|
||||
{trigger_str}
|
||||
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
|
||||
|
||||
- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**.
|
||||
- Place it on your `models/Lora` folder.
|
||||
- On AUTOMATIC1111, load the LoRA by adding `<lora:{repo_folder}:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
|
||||
{webui_example_pivotal}
|
||||
|
||||
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
||||
|
||||
@@ -193,16 +185,12 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}'
|
||||
|
||||
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
|
||||
|
||||
## Download model
|
||||
## Trigger words
|
||||
|
||||
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
|
||||
|
||||
- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it on your Lora folder.
|
||||
- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it on it on your embeddings folder.
|
||||
|
||||
All [Files & versions](/{repo_id}/tree/main).
|
||||
{trigger_str}
|
||||
|
||||
## Details
|
||||
All [Files & versions](/{repo_id}/tree/main).
|
||||
|
||||
The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
|
||||
|
||||
@@ -1264,54 +1252,25 @@ def main(args):
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
# Set correct lora layers
|
||||
unet_lora_parameters = []
|
||||
for attn_processor_name, attn_processor in unet.attn_processors.items():
|
||||
# Parse the attention module.
|
||||
attn_module = unet
|
||||
for n in attn_processor_name.split(".")[:-1]:
|
||||
attn_module = getattr(attn_module, n)
|
||||
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
|
||||
# Accumulate the LoRA params to optimize.
|
||||
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
unet.add_adapter(unet_lora_config)
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
if args.train_text_encoder:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
|
||||
text_encoder_one, dtype=torch.float32, rank=args.rank
|
||||
)
|
||||
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
|
||||
text_encoder_two, dtype=torch.float32, rank=args.rank
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
# if we use textual inversion, we freeze all parameters except for the token embeddings
|
||||
# in text encoder
|
||||
@@ -1335,6 +1294,11 @@ def main(args):
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
@@ -1345,12 +1309,18 @@ def main(args):
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1363,6 +1333,8 @@ def main(args):
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
)
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(f"{output_dir}/{args.output_dir}_emb.safetensors")
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
unet_ = None
|
||||
@@ -1372,27 +1344,44 @@ def main(args):
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
unet_ = model
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_ = model
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
)
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
|
||||
|
||||
_set_state_dict_into_text_encoder(
|
||||
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32. This is again needed since the base models
|
||||
# are in `weight_dtype`. More details:
|
||||
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet_]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one_, text_encoder_two_])
|
||||
cast_training_params(models)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
@@ -1407,6 +1396,19 @@ def main(args):
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
|
||||
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
|
||||
|
||||
# If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
|
||||
freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
|
||||
|
||||
@@ -1841,9 +1843,17 @@ def main(args):
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
|
||||
if args.with_prior_preservation:
|
||||
# if we're using prior preservation, we calc snr for instance loss only -
|
||||
# and hence only need timesteps corresponding to instance images
|
||||
snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0)
|
||||
else:
|
||||
snr_timesteps = timesteps
|
||||
|
||||
snr = compute_snr(noise_scheduler, snr_timesteps)
|
||||
base_weight = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
@@ -1997,13 +2007,17 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = unet_lora_state_dict(unet)
|
||||
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
|
||||
text_encoder_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
)
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
|
||||
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
)
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
@@ -2073,8 +2087,15 @@ def main(args):
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/embeddings.safetensors",
|
||||
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
|
||||
)
|
||||
|
||||
# Conver to WebUI format
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
|
||||
|
||||
save_model_card(
|
||||
model_id if not args.push_to_hub else repo_id,
|
||||
images=images,
|
||||
|
||||
@@ -21,6 +21,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
|
||||
| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | - | [Phạm Hồng Vinh](https://github.com/rootonchair) |
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
|
||||
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
|
||||
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
|
||||
@@ -54,6 +55,10 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
|
||||
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
|
||||
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
|
||||
| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
|
||||
| Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
|
||||
| Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender_A_Video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
|
||||
| StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
```py
|
||||
@@ -738,6 +743,49 @@ grid = image_grid(images, rows=2, cols=2)
|
||||
This example produces the following images:
|
||||

|
||||
|
||||
### GlueGen Stable Diffusion Pipeline
|
||||
GlueGen is a minimal adapter that allow alignment between any encoder (Text Encoder of different language, Multilingual Roberta, AudioClip) and CLIP text encoder used in standard Stable Diffusion model. This method allows easy language adaptation to available english Stable Diffusion checkpoints without the need of an image captioning dataset as well as long training hours.
|
||||
|
||||
Make sure you downloaded `gluenet_French_clip_overnorm_over3_noln.ckpt` for French (there are also pre-trained weights for Chinese, Italian, Japanese, Spanish or train your own) at [GlueGen's official repo](https://github.com/salesforce/GlueGen/tree/main)
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = "cuda"
|
||||
|
||||
lm_model_id = "xlm-roberta-large"
|
||||
token_max_length = 77
|
||||
|
||||
text_encoder = AutoModel.from_pretrained(lm_model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(lm_model_id, model_max_length=token_max_length, use_fast=False)
|
||||
|
||||
tensor_norm = torch.Tensor([[43.8203],[28.3668],[27.9345],[28.0084],[28.2958],[28.2576],[28.3373],[28.2695],[28.4097],[28.2790],[28.2825],[28.2807],[28.2775],[28.2708],[28.2682],[28.2624],[28.2589],[28.2611],[28.2616],[28.2639],[28.2613],[28.2566],[28.2615],[28.2665],[28.2799],[28.2885],[28.2852],[28.2863],[28.2780],[28.2818],[28.2764],[28.2532],[28.2412],[28.2336],[28.2514],[28.2734],[28.2763],[28.2977],[28.2971],[28.2948],[28.2818],[28.2676],[28.2831],[28.2890],[28.2979],[28.2999],[28.3117],[28.3363],[28.3554],[28.3626],[28.3589],[28.3597],[28.3543],[28.3660],[28.3731],[28.3717],[28.3812],[28.3753],[28.3810],[28.3777],[28.3693],[28.3713],[28.3670],[28.3691],[28.3679],[28.3624],[28.3703],[28.3703],[28.3720],[28.3594],[28.3576],[28.3562],[28.3438],[28.3376],[28.3389],[28.3433],[28.3191]])
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
custom_pipeline="gluegen"
|
||||
).to(device)
|
||||
pipeline.load_language_adapter("gluenet_French_clip_overnorm_over3_noln.ckpt", num_token=token_max_length, dim=1024, dim_out=768, tensor_norm=tensor_norm)
|
||||
|
||||
prompt = "une voiture sur la plage"
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image.save("gluegen_output_fr.png")
|
||||
```
|
||||
Which will produce:
|
||||
|
||||

|
||||
|
||||
### Image to Image Inpainting Stable Diffusion
|
||||
|
||||
Similar to the standard stable diffusion inpainting example, except with the addition of an `inner_image` argument.
|
||||
@@ -2942,7 +2990,7 @@ pipe = DiffusionPipeline.from_pretrained(
|
||||
custom_pipeline="pipeline_animatediff_controlnet",
|
||||
).to(device="cuda", dtype=torch.float16)
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
|
||||
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
|
||||
)
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
@@ -2958,7 +3006,7 @@ result = pipe(
|
||||
width=512,
|
||||
height=768,
|
||||
conditioning_frames=conditioning_frames,
|
||||
num_inference_steps=12,
|
||||
num_inference_steps=20,
|
||||
).frames[0]
|
||||
|
||||
from diffusers.utils import export_to_gif
|
||||
@@ -2981,7 +3029,82 @@ export_to_gif(result.frames[0], "result.gif")
|
||||
<td align=center><img src="https://github.com/huggingface/diffusers/assets/72266394/eb7d2952-72e4-44fa-b664-077c79b4fc70" alt="gif-2"></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
You can also use multiple controlnets at once!
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter
|
||||
from diffusers.pipelines import DiffusionPipeline
|
||||
from diffusers.schedulers import DPMSolverMultistepScheduler
|
||||
from PIL import Image
|
||||
|
||||
motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
|
||||
adapter = MotionAdapter.from_pretrained(motion_id)
|
||||
controlnet1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
|
||||
controlnet2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
|
||||
|
||||
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
motion_adapter=adapter,
|
||||
controlnet=[controlnet1, controlnet2],
|
||||
vae=vae,
|
||||
custom_pipeline="pipeline_animatediff_controlnet",
|
||||
).to(device="cuda", dtype=torch.float16)
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
|
||||
)
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
def load_video(file_path: str):
|
||||
images = []
|
||||
|
||||
if file_path.startswith(('http://', 'https://')):
|
||||
# If the file_path is a URL
|
||||
response = requests.get(file_path)
|
||||
response.raise_for_status()
|
||||
content = BytesIO(response.content)
|
||||
vid = imageio.get_reader(content)
|
||||
else:
|
||||
# Assuming it's a local file path
|
||||
vid = imageio.get_reader(file_path)
|
||||
|
||||
for frame in vid:
|
||||
pil_image = Image.fromarray(frame)
|
||||
images.append(pil_image)
|
||||
|
||||
return images
|
||||
|
||||
video = load_video("dance.gif")
|
||||
|
||||
# You need to install it using `pip install controlnet_aux`
|
||||
from controlnet_aux.processor import Processor
|
||||
|
||||
p1 = Processor("openpose_full")
|
||||
cn1 = [p1(frame) for frame in video]
|
||||
|
||||
p2 = Processor("canny")
|
||||
cn2 = [p2(frame) for frame in video]
|
||||
|
||||
prompt = "astronaut in space, dancing"
|
||||
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
|
||||
result = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=512,
|
||||
height=768,
|
||||
conditioning_frames=[cn1, cn2],
|
||||
num_inference_steps=20,
|
||||
)
|
||||
|
||||
from diffusers.utils import export_to_gif
|
||||
export_to_gif(result.frames[0], "result.gif")
|
||||
```
|
||||
|
||||
### DemoFusion
|
||||
|
||||
This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
|
||||
The original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion).
|
||||
- `view_batch_size` (`int`, defaults to 16):
|
||||
@@ -3102,3 +3225,245 @@ output_image = PIL.Image.fromarray(output)
|
||||
output_image.save("./output.png")
|
||||
|
||||
```
|
||||
|
||||
### Null-Text Inversion pipeline
|
||||
|
||||
This pipeline provides null-text inversion for editing real images. It enables null-text optimization, and DDIM reconstruction via w, w/o null-text optimization. No prompt-to-prompt code is implemented as there is a Prompt2PromptPipeline.
|
||||
* Reference paper
|
||||
```@article{hertz2022prompt,
|
||||
title={Prompt-to-prompt image editing with cross attention control},
|
||||
author={Hertz, Amir and Mokady, Ron and Tenenbaum, Jay and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel},
|
||||
booktitle={arXiv preprint arXiv:2208.01626},
|
||||
year={2022}
|
||||
```}
|
||||
|
||||
```py
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
from examples.community.pipeline_null_text_inversion import NullTextPipeline
|
||||
import torch
|
||||
|
||||
# Load the pipeline
|
||||
device = "cuda"
|
||||
# Provide invert_prompt and the image for null-text optimization.
|
||||
invert_prompt = "A lying cat"
|
||||
input_image = "siamese.jpg"
|
||||
steps = 50
|
||||
|
||||
# Provide prompt used for generation. Same if reconstruction
|
||||
prompt = "A lying cat"
|
||||
# or different if editing.
|
||||
prompt = "A lying dog"
|
||||
|
||||
#Float32 is essential to a well optimization
|
||||
model_path = "runwayml/stable-diffusion-v1-5"
|
||||
scheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear")
|
||||
pipeline = NullTextPipeline.from_pretrained(model_path, scheduler = scheduler, torch_dtype=torch.float32).to(device)
|
||||
|
||||
#Saves the inverted_latent to save time
|
||||
inverted_latent, uncond = pipeline.invert(input_image, invert_prompt, num_inner_steps=10, early_stop_epsilon= 1e-5, num_inference_steps = steps)
|
||||
pipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_steps=steps).images[0].save(input_image+".output.jpg")
|
||||
```
|
||||
### Rerender_A_Video
|
||||
|
||||
This is the Diffusers implementation of zero-shot video-to-video translation pipeline [Rerender_A_Video](https://github.com/williamyang1991/Rerender_A_Video) (without Ebsynth postprocessing). To run the code, please install gmflow. Then modify the path in `examples/community/rerender_a_video.py`:
|
||||
|
||||
```py
|
||||
gmflow_dir = "/path/to/gmflow"
|
||||
```
|
||||
|
||||
After that, you can run the pipeline with:
|
||||
|
||||
```py
|
||||
from diffusers import ControlNetModel, AutoencoderKL, DDIMScheduler
|
||||
from diffusers.utils import export_to_video
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
def video_to_frame(video_path: str, interval: int):
|
||||
vidcap = cv2.VideoCapture(video_path)
|
||||
success = True
|
||||
|
||||
count = 0
|
||||
res = []
|
||||
while success:
|
||||
count += 1
|
||||
success, image = vidcap.read()
|
||||
if count % interval != 1:
|
||||
continue
|
||||
if image is not None:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
res.append(image)
|
||||
|
||||
vidcap.release()
|
||||
return res
|
||||
|
||||
input_video_path = 'path/to/video'
|
||||
input_interval = 10
|
||||
frames = video_to_frame(
|
||||
input_video_path, input_interval)
|
||||
|
||||
control_frames = []
|
||||
# get canny image
|
||||
for frame in frames:
|
||||
np_image = cv2.Canny(frame, 50, 100)
|
||||
np_image = np_image[:, :, None]
|
||||
np_image = np.concatenate([np_image, np_image, np_image], axis=2)
|
||||
canny_image = Image.fromarray(np_image)
|
||||
control_frames.append(canny_image)
|
||||
|
||||
# You can use any ControlNet here
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/sd-controlnet-canny").to('cuda')
|
||||
|
||||
# You can use any fintuned SD here
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, custom_pipeline='rerender_a_video').to('cuda')
|
||||
|
||||
# Optional: you can download vae-ft-mse-840000-ema-pruned.ckpt to enhance the results
|
||||
# pipe.vae = AutoencoderKL.from_single_file(
|
||||
# "path/to/vae-ft-mse-840000-ema-pruned.ckpt").to('cuda')
|
||||
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
output_frames = pipe(
|
||||
"a beautiful woman in CG style, best quality, extremely detailed",
|
||||
|
||||
frames,
|
||||
control_frames,
|
||||
num_inference_steps=20,
|
||||
strength=0.75,
|
||||
controlnet_conditioning_scale=0.7,
|
||||
generator=generator,
|
||||
warp_start=0.0,
|
||||
warp_end=0.1,
|
||||
mask_start=0.5,
|
||||
mask_end=0.8,
|
||||
mask_strength=0.5,
|
||||
negative_prompt='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
|
||||
).frames
|
||||
|
||||
export_to_video(
|
||||
output_frames, "/path/to/video.mp4", 5)
|
||||
```
|
||||
|
||||
### StyleAligned Pipeline
|
||||
|
||||
This pipeline is the implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133).
|
||||
|
||||
> Large-scale Text-to-Image (T2I) models have rapidly gained prominence across creative fields, generating visually compelling outputs from textual prompts. However, controlling these models to ensure consistent style remains challenging, with existing methods necessitating fine-tuning and manual intervention to disentangle content and style. In this paper, we introduce StyleAligned, a novel technique designed to establish style alignment among a series of generated images. By employing minimal `attention sharing' during the diffusion process, our method maintains style consistency across images within T2I models. This approach allows for the creation of style-consistent images using a reference style through a straightforward inversion operation. Our method's evaluation across diverse styles and text prompts demonstrates high-quality synthesis and fidelity, underscoring its efficacy in achieving consistent style across various inputs.
|
||||
|
||||
```python
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from PIL import Image
|
||||
|
||||
model_id = "a-r-r-o-w/dreamshaper-xl-turbo"
|
||||
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", custom_pipeline="pipeline_sdxl_style_aligned")
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
# Enable memory saving techniques
|
||||
pipe.enable_vae_slicing()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
prompt = [
|
||||
"a toy train. macro photo. 3d game asset",
|
||||
"a toy airplane. macro photo. 3d game asset",
|
||||
"a toy bicycle. macro photo. 3d game asset",
|
||||
"a toy car. macro photo. 3d game asset",
|
||||
]
|
||||
negative_prompt = "low quality, worst quality, "
|
||||
|
||||
# Enable StyleAligned
|
||||
pipe.enable_style_aligned(
|
||||
share_group_norm=False,
|
||||
share_layer_norm=False,
|
||||
share_attention=True,
|
||||
adain_queries=True,
|
||||
adain_keys=True,
|
||||
adain_values=False,
|
||||
full_attention_share=False,
|
||||
shared_score_scale=1.0,
|
||||
shared_score_shift=0.0,
|
||||
only_self_level=0.0,
|
||||
)
|
||||
|
||||
# Run inference
|
||||
images = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=2,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=10,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images
|
||||
|
||||
# Disable StyleAligned if you do not wish to use it anymore
|
||||
pipe.disable_style_aligned()
|
||||
```
|
||||
|
||||
### IP Adapter Face ID
|
||||
IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
|
||||
You need to install `insightface` and all its requirements to use this model.
|
||||
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
|
||||
You have to disable PEFT BACKEND in order to load weights.
|
||||
|
||||
```py
|
||||
import diffusers
|
||||
diffusers.utils.USE_PEFT_BACKEND = False
|
||||
import torch
|
||||
from diffusers.utils import load_image
|
||||
import cv2
|
||||
import numpy as np
|
||||
from diffusers import DiffusionPipeline, AutoencoderKL, DDIMScheduler
|
||||
from insightface.app import FaceAnalysis
|
||||
|
||||
|
||||
noise_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
steps_offset=1,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float16)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/Realistic_Vision_V4.0_noVAE",
|
||||
torch_dtype=torch.float16,
|
||||
scheduler=noise_scheduler,
|
||||
vae=vae,
|
||||
custom_pipeline="ip_adapter_face_id"
|
||||
)
|
||||
pipeline.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin")
|
||||
pipeline.to("cuda")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(42)
|
||||
num_images=2
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
|
||||
|
||||
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
|
||||
faces = app.get(image)
|
||||
image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
|
||||
images = pipeline(
|
||||
prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
|
||||
image_embeds=image,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704,
|
||||
generator=generator
|
||||
).images
|
||||
|
||||
for i in range(num_images):
|
||||
images[i].save(f"c{i}.png")
|
||||
```
|
||||
|
||||
@@ -0,0 +1,865 @@
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class TranslatorBase(nn.Module):
|
||||
def __init__(self, num_tok, dim, dim_out, mult=2):
|
||||
super().__init__()
|
||||
|
||||
self.dim_in = dim
|
||||
self.dim_out = dim_out
|
||||
|
||||
self.net_tok = nn.Sequential(
|
||||
nn.Linear(num_tok, int(num_tok * mult)),
|
||||
nn.LayerNorm(int(num_tok * mult)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(num_tok * mult), int(num_tok * mult)),
|
||||
nn.LayerNorm(int(num_tok * mult)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(num_tok * mult), num_tok),
|
||||
nn.LayerNorm(num_tok),
|
||||
)
|
||||
|
||||
self.net_sen = nn.Sequential(
|
||||
nn.Linear(dim, int(dim * mult)),
|
||||
nn.LayerNorm(int(dim * mult)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(dim * mult), int(dim * mult)),
|
||||
nn.LayerNorm(int(dim * mult)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(dim * mult), dim_out),
|
||||
nn.LayerNorm(dim_out),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.dim_in == self.dim_out:
|
||||
indentity_0 = x
|
||||
x = self.net_sen(x)
|
||||
x += indentity_0
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
indentity_1 = x
|
||||
x = self.net_tok(x)
|
||||
x += indentity_1
|
||||
x = x.transpose(1, 2)
|
||||
else:
|
||||
x = self.net_sen(x)
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
x = self.net_tok(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class TranslatorBaseNoLN(nn.Module):
|
||||
def __init__(self, num_tok, dim, dim_out, mult=2):
|
||||
super().__init__()
|
||||
|
||||
self.dim_in = dim
|
||||
self.dim_out = dim_out
|
||||
|
||||
self.net_tok = nn.Sequential(
|
||||
nn.Linear(num_tok, int(num_tok * mult)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(num_tok * mult), int(num_tok * mult)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(num_tok * mult), num_tok),
|
||||
)
|
||||
|
||||
self.net_sen = nn.Sequential(
|
||||
nn.Linear(dim, int(dim * mult)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(dim * mult), int(dim * mult)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(dim * mult), dim_out),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.dim_in == self.dim_out:
|
||||
indentity_0 = x
|
||||
x = self.net_sen(x)
|
||||
x += indentity_0
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
indentity_1 = x
|
||||
x = self.net_tok(x)
|
||||
x += indentity_1
|
||||
x = x.transpose(1, 2)
|
||||
else:
|
||||
x = self.net_sen(x)
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
x = self.net_tok(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class TranslatorNoLN(nn.Module):
|
||||
def __init__(self, num_tok, dim, dim_out, mult=2, depth=5):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList([TranslatorBase(num_tok, dim, dim, mult=2) for d in range(depth)])
|
||||
self.gelu = nn.GELU()
|
||||
|
||||
self.tail = TranslatorBaseNoLN(num_tok, dim, dim_out, mult=2)
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x) + x
|
||||
x = self.gelu(x)
|
||||
|
||||
x = self.tail(x)
|
||||
return x
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
||||
must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class GlueGenStableDiffusionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
language_adapter: TranslatorNoLN = None,
|
||||
tensor_norm: torch.FloatTensor = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
language_adapter=language_adapter,
|
||||
tensor_norm=tensor_norm,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def load_language_adapter(
|
||||
self,
|
||||
model_path: str,
|
||||
num_token: int,
|
||||
dim: int,
|
||||
dim_out: int,
|
||||
tensor_norm: torch.FloatTensor,
|
||||
mult: int = 2,
|
||||
depth: int = 5,
|
||||
):
|
||||
device = self._execution_device
|
||||
self.tensor_norm = tensor_norm.to(device)
|
||||
self.language_adapter = TranslatorNoLN(num_tok=num_token, dim=dim, dim_out=dim_out, mult=mult, depth=depth).to(
|
||||
device
|
||||
)
|
||||
self.language_adapter.load_state_dict(torch.load(model_path))
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def _adapt_language(self, prompt_embeds: torch.FloatTensor):
|
||||
prompt_embeds = prompt_embeds / 3
|
||||
prompt_embeds = self.language_adapter(prompt_embeds) * (self.tensor_norm / 2)
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
elif self.language_adapter is not None:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
# Run prompt language adapter
|
||||
if self.language_adapter is not None:
|
||||
prompt_embeds = self._adapt_language(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
# Run negative prompt language adapter
|
||||
if self.language_adapter is not None:
|
||||
negative_prompt_embeds = self._adapt_language(negative_prompt_embeds)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
if torch.is_tensor(image):
|
||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
||||
else:
|
||||
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||
|
||||
Args:
|
||||
timesteps (`torch.Tensor`):
|
||||
generate embedding vectors at these timesteps
|
||||
embedding_dim (`int`, *optional*, defaults to 512):
|
||||
dimension of the embeddings to generate
|
||||
dtype:
|
||||
data type of the generated embeddings
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
||||
"""
|
||||
assert len(w.shape) == 1
|
||||
w = w * 1000.0
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
||||
emb = w.to(dtype)[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1))
|
||||
assert emb.shape == (w.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
clip_skip: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
# to deal with lora scaling and other possible forward hooks
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Encode input prompt
|
||||
lora_scale = (
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.2 Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
if self.unet.config.time_cond_proj_dim is not None:
|
||||
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
||||
timestep_cond = self.get_guidance_scale_embedding(
|
||||
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
||||
).to(device=device, dtype=latents.dtype)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,14 +12,21 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
@@ -27,6 +34,7 @@ from diffusers.models.attention_processor import (
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_invisible_watermark_available,
|
||||
@@ -252,6 +260,7 @@ def get_weighted_text_embeddings_sdxl(
|
||||
neg_prompt_2: str = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
This function can process long prompt with weights, no length limitation
|
||||
@@ -265,6 +274,7 @@ def get_weighted_text_embeddings_sdxl(
|
||||
neg_prompt_2 (str)
|
||||
num_images_per_prompt (int)
|
||||
device (torch.device)
|
||||
clip_skip (int)
|
||||
Returns:
|
||||
prompt_embeds (torch.Tensor)
|
||||
neg_prompt_embeds (torch.Tensor)
|
||||
@@ -277,17 +287,24 @@ def get_weighted_text_embeddings_sdxl(
|
||||
if neg_prompt_2:
|
||||
neg_prompt = f"{neg_prompt} {neg_prompt_2}"
|
||||
|
||||
prompt_t1 = prompt_t2 = prompt
|
||||
neg_prompt_t1 = neg_prompt_t2 = neg_prompt
|
||||
|
||||
if isinstance(pipe, TextualInversionLoaderMixin):
|
||||
prompt_t1 = pipe.maybe_convert_prompt(prompt_t1, pipe.tokenizer)
|
||||
neg_prompt_t1 = pipe.maybe_convert_prompt(neg_prompt_t1, pipe.tokenizer)
|
||||
prompt_t2 = pipe.maybe_convert_prompt(prompt_t2, pipe.tokenizer_2)
|
||||
neg_prompt_t2 = pipe.maybe_convert_prompt(neg_prompt_t2, pipe.tokenizer_2)
|
||||
|
||||
eos = pipe.tokenizer.eos_token_id
|
||||
|
||||
# tokenizer 1
|
||||
prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
|
||||
|
||||
neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
|
||||
prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt_t1)
|
||||
neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt_t1)
|
||||
|
||||
# tokenizer 2
|
||||
prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt)
|
||||
|
||||
neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt)
|
||||
prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt_t2)
|
||||
neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt_t2)
|
||||
|
||||
# padding the shorter one for prompt set 1
|
||||
prompt_token_len = len(prompt_tokens)
|
||||
@@ -342,13 +359,19 @@ def get_weighted_text_embeddings_sdxl(
|
||||
|
||||
# use first text encoder
|
||||
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
|
||||
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
|
||||
|
||||
# use second text encoder
|
||||
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
|
||||
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
|
||||
pooled_prompt_embeds = prompt_embeds_2[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
|
||||
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
|
||||
else:
|
||||
# "2" because SDXL always indexes from the penultimate layer.
|
||||
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-(clip_skip + 2)]
|
||||
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-(clip_skip + 2)]
|
||||
|
||||
prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
|
||||
token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
|
||||
|
||||
@@ -521,19 +544,21 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
|
||||
class SDXLLongPromptWeightingPipeline(
|
||||
DiffusionPipeline, FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion XL.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -554,12 +579,34 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Second Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
|
||||
_optional_components = [
|
||||
"tokenizer",
|
||||
"tokenizer_2",
|
||||
"text_encoder",
|
||||
"text_encoder_2",
|
||||
"image_encoder",
|
||||
"feature_extractor",
|
||||
]
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
"add_text_embeds",
|
||||
"add_time_ids",
|
||||
"negative_pooled_prompt_embeds",
|
||||
"negative_add_time_ids",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
@@ -569,6 +616,8 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
):
|
||||
@@ -582,6 +631,8 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
@@ -661,6 +712,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -852,6 +904,31 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
@@ -884,6 +961,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
@@ -891,14 +969,19 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
@@ -947,6 +1030,95 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
"""
|
||||
self.fusing_unet = False
|
||||
self.fusing_vae = False
|
||||
|
||||
if unet:
|
||||
self.fusing_unet = True
|
||||
self.unet.fuse_qkv_projections()
|
||||
self.unet.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
if vae:
|
||||
if not isinstance(self.vae, AutoencoderKL):
|
||||
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
|
||||
|
||||
self.fusing_vae = True
|
||||
self.vae.fuse_qkv_projections()
|
||||
self.vae.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""Disable QKV projection fusion if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
|
||||
"""
|
||||
if unet:
|
||||
if not self.fusing_unet:
|
||||
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.unet.unfuse_qkv_projections()
|
||||
self.fusing_unet = False
|
||||
|
||||
if vae:
|
||||
if not self.fusing_vae:
|
||||
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.vae.unfuse_qkv_projections()
|
||||
self.fusing_vae = False
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
|
||||
# get the original timestep using init_timestep
|
||||
if denoising_start is None:
|
||||
@@ -1241,6 +1413,35 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||
|
||||
Args:
|
||||
timesteps (`torch.Tensor`):
|
||||
generate embedding vectors at these timesteps
|
||||
embedding_dim (`int`, *optional*, defaults to 512):
|
||||
dimension of the embeddings to generate
|
||||
dtype:
|
||||
data type of the generated embeddings
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
||||
"""
|
||||
assert len(w.shape) == 1
|
||||
w = w * 1000.0
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
||||
emb = w.to(dtype)[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1))
|
||||
assert emb.shape == (w.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -1249,6 +1450,10 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@@ -1295,19 +1500,22 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -1384,6 +1592,8 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*):
|
||||
Optional image input to work with IP Adapters.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
@@ -1404,12 +1614,6 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
@@ -1433,6 +1637,18 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -1441,6 +1657,23 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
@@ -1462,10 +1695,12 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._denoising_start = denoising_start
|
||||
@@ -1480,13 +1715,16 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 3. Encode input prompt
|
||||
(cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)
|
||||
(self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
|
||||
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else ""
|
||||
|
||||
@@ -1496,7 +1734,11 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = get_weighted_text_embeddings_sdxl(
|
||||
pipe=self, prompt=prompt, neg_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt
|
||||
pipe=self,
|
||||
prompt=prompt,
|
||||
neg_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
dtype = prompt_embeds.dtype
|
||||
|
||||
@@ -1576,7 +1818,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
else:
|
||||
latents, noise = latents
|
||||
|
||||
# 5.1. Prepare mask latent variables
|
||||
# 5.1 Prepare mask latent variables
|
||||
if mask is not None:
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask=mask,
|
||||
@@ -1590,7 +1832,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
# Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
@@ -1611,6 +1853,9 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else {}
|
||||
|
||||
height, width = latents.shape[-2:]
|
||||
height = height * self.vae_scale_factor
|
||||
width = width * self.vae_scale_factor
|
||||
@@ -1624,7 +1869,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
@@ -1671,7 +1916,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -1679,7 +1924,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
added_cond_kwargs.update({"text_embeds": add_text_embeds, "time_ids": add_time_ids})
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
@@ -1691,11 +1936,11 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
@@ -1718,6 +1963,21 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
|
||||
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
@@ -1774,20 +2034,28 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling pipeline for text-to-image.
|
||||
|
||||
Refer to the documentation of the `__call__` method for parameter descriptions.
|
||||
"""
|
||||
return self.__call__(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
@@ -1804,19 +2072,22 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
ip_adapter_image=ip_adapter_image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
original_size=original_size,
|
||||
crops_coords_top_left=crops_coords_top_left,
|
||||
target_size=target_size,
|
||||
clip_skip=clip_skip,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def img2img(
|
||||
@@ -1838,20 +2109,28 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling pipeline for image-to-image.
|
||||
|
||||
Refer to the documentation of the `__call__` method for parameter descriptions.
|
||||
"""
|
||||
return self.__call__(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
@@ -1870,19 +2149,22 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
ip_adapter_image=ip_adapter_image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
original_size=original_size,
|
||||
crops_coords_top_left=crops_coords_top_left,
|
||||
target_size=target_size,
|
||||
clip_skip=clip_skip,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def inpaint(
|
||||
@@ -1906,20 +2188,28 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling pipeline for inpainting.
|
||||
|
||||
Refer to the documentation of the `__call__` method for parameter descriptions.
|
||||
"""
|
||||
return self.__call__(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
@@ -1940,19 +2230,22 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
ip_adapter_image=ip_adapter_image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
original_size=original_size,
|
||||
crops_coords_top_left=crops_coords_top_left,
|
||||
target_size=target_size,
|
||||
clip_skip=clip_skip,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Overrride to properly handle the loading and unloading of the additional text encoder.
|
||||
|
||||
@@ -40,7 +40,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.1.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -66,7 +66,7 @@ EXAMPLE_DOC_STRING = """
|
||||
... custom_pipeline="pipeline_animatediff_controlnet",
|
||||
... ).to(device="cuda", dtype=torch.float16)
|
||||
>>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
||||
... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
|
||||
... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
|
||||
... )
|
||||
>>> pipe.enable_vae_slicing()
|
||||
|
||||
@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
|
||||
... height=768,
|
||||
... conditioning_frames=conditioning_frames,
|
||||
... num_inference_steps=12,
|
||||
... ).frames[0]
|
||||
... )
|
||||
|
||||
>>> from diffusers.utils import export_to_gif
|
||||
>>> export_to_gif(result.frames[0], "result.gif")
|
||||
@@ -151,7 +151,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
motion_adapter: MotionAdapter,
|
||||
controlnet: Union[ControlNetModel, MultiControlNetModel],
|
||||
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
@@ -166,6 +166,9 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
|
||||
super().__init__()
|
||||
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
|
||||
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = MultiControlNetModel(controlnet)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -488,6 +491,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
@@ -557,31 +561,21 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
||||
):
|
||||
if isinstance(image, list):
|
||||
for image_ in image:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
else:
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
if not isinstance(image, list):
|
||||
raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}")
|
||||
if len(image) != num_frames:
|
||||
raise ValueError(f"Excepted image to have length {num_frames} but got {len(image)=}")
|
||||
elif (
|
||||
isinstance(self.controlnet, MultiControlNetModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
|
||||
):
|
||||
if not isinstance(image, list):
|
||||
raise TypeError("For multiple controlnets: `image` must be type `list`")
|
||||
|
||||
# When `image` is a nested list:
|
||||
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
|
||||
elif any(isinstance(i, list) for i in image):
|
||||
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
|
||||
elif len(image) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
|
||||
)
|
||||
|
||||
for control_ in image:
|
||||
for image_ in control_:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
if not isinstance(image, list) or not isinstance(image[0], list):
|
||||
raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(image)=}")
|
||||
if len(image[0]) != num_frames:
|
||||
raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(image[0])=}")
|
||||
if any(len(img) != len(image[0]) for img in image):
|
||||
raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
|
||||
else:
|
||||
assert False
|
||||
|
||||
@@ -913,6 +907,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
callback_steps=callback_steps,
|
||||
negative_prompt=negative_prompt,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
@@ -1000,9 +995,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
cond_prepared_frames.append(prepared_frame)
|
||||
|
||||
conditioning_frames = cond_prepared_frames
|
||||
else:
|
||||
assert False
|
||||
|
||||
@@ -0,0 +1,260 @@
|
||||
import inspect
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as nnf
|
||||
from PIL import Image
|
||||
from torch.optim.adam import Adam
|
||||
from tqdm import tqdm
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps=None,
|
||||
device=None,
|
||||
timesteps=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
||||
must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class NullTextPipeline(StableDiffusionPipeline):
|
||||
def get_noise_pred(self, latents, t, context):
|
||||
latents_input = torch.cat([latents] * 2)
|
||||
guidance_scale = 7.5
|
||||
noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
|
||||
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
|
||||
latents = self.prev_step(noise_pred, t, latents)
|
||||
return latents
|
||||
|
||||
def get_noise_pred_single(self, latents, t, context):
|
||||
noise_pred = self.unet(latents, t, encoder_hidden_states=context)["sample"]
|
||||
return noise_pred
|
||||
|
||||
@torch.no_grad()
|
||||
def image2latent(self, image_path):
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
image = np.array(image)
|
||||
image = torch.from_numpy(image).float() / 127.5 - 1
|
||||
image = image.permute(2, 0, 1).unsqueeze(0).to(self.device)
|
||||
latents = self.vae.encode(image)["latent_dist"].mean
|
||||
latents = latents * 0.18215
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def latent2image(self, latents):
|
||||
latents = 1 / 0.18215 * latents.detach()
|
||||
image = self.vae.decode(latents)["sample"].detach()
|
||||
image = self.processor.postprocess(image, output_type="pil")[0]
|
||||
return image
|
||||
|
||||
def prev_step(self, model_output, timestep, sample):
|
||||
prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
|
||||
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = (
|
||||
self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
|
||||
)
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
|
||||
prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction
|
||||
return prev_sample
|
||||
|
||||
def next_step(self, model_output, timestep, sample):
|
||||
timestep, next_timestep = (
|
||||
min(timestep - self.scheduler.config.num_train_timesteps // self.num_inference_steps, 999),
|
||||
timestep,
|
||||
)
|
||||
alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
|
||||
alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
|
||||
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
|
||||
next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
|
||||
return next_sample
|
||||
|
||||
def null_optimization(self, latents, context, num_inner_steps, epsilon):
|
||||
uncond_embeddings, cond_embeddings = context.chunk(2)
|
||||
uncond_embeddings_list = []
|
||||
latent_cur = latents[-1]
|
||||
bar = tqdm(total=num_inner_steps * self.num_inference_steps)
|
||||
for i in range(self.num_inference_steps):
|
||||
uncond_embeddings = uncond_embeddings.clone().detach()
|
||||
uncond_embeddings.requires_grad = True
|
||||
optimizer = Adam([uncond_embeddings], lr=1e-2 * (1.0 - i / 100.0))
|
||||
latent_prev = latents[len(latents) - i - 2]
|
||||
t = self.scheduler.timesteps[i]
|
||||
with torch.no_grad():
|
||||
noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
|
||||
for j in range(num_inner_steps):
|
||||
noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
|
||||
noise_pred = noise_pred_uncond + 7.5 * (noise_pred_cond - noise_pred_uncond)
|
||||
latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
|
||||
loss = nnf.mse_loss(latents_prev_rec, latent_prev)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
loss_item = loss.item()
|
||||
bar.update()
|
||||
if loss_item < epsilon + i * 2e-5:
|
||||
break
|
||||
for j in range(j + 1, num_inner_steps):
|
||||
bar.update()
|
||||
uncond_embeddings_list.append(uncond_embeddings[:1].detach())
|
||||
with torch.no_grad():
|
||||
context = torch.cat([uncond_embeddings, cond_embeddings])
|
||||
latent_cur = self.get_noise_pred(latent_cur, t, context)
|
||||
bar.close()
|
||||
return uncond_embeddings_list
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_inversion_loop(self, latent, context):
|
||||
self.scheduler.set_timesteps(self.num_inference_steps)
|
||||
_, cond_embeddings = context.chunk(2)
|
||||
all_latent = [latent]
|
||||
latent = latent.clone().detach()
|
||||
with torch.no_grad():
|
||||
for i in range(0, self.num_inference_steps):
|
||||
t = self.scheduler.timesteps[len(self.scheduler.timesteps) - i - 1]
|
||||
noise_pred = self.unet(latent, t, encoder_hidden_states=cond_embeddings)["sample"]
|
||||
latent = self.next_step(noise_pred, t, latent)
|
||||
all_latent.append(latent)
|
||||
return all_latent
|
||||
|
||||
def get_context(self, prompt):
|
||||
uncond_input = self.tokenizer(
|
||||
[""], padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
text_input = self.tokenizer(
|
||||
[prompt],
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
context = torch.cat([uncond_embeddings, text_embeddings])
|
||||
return context
|
||||
|
||||
def invert(
|
||||
self, image_path: str, prompt: str, num_inner_steps=10, early_stop_epsilon=1e-6, num_inference_steps=50
|
||||
):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
context = self.get_context(prompt)
|
||||
latent = self.image2latent(image_path)
|
||||
ddim_latents = self.ddim_inversion_loop(latent, context)
|
||||
if os.path.exists(image_path + ".pt"):
|
||||
uncond_embeddings = torch.load(image_path + ".pt")
|
||||
else:
|
||||
uncond_embeddings = self.null_optimization(ddim_latents, context, num_inner_steps, early_stop_epsilon)
|
||||
uncond_embeddings = torch.stack(uncond_embeddings, 0)
|
||||
torch.save(uncond_embeddings, image_path + ".pt")
|
||||
return ddim_latents[-1], uncond_embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
uncond_embeddings,
|
||||
inverted_latent,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps=None,
|
||||
guidance_scale=7.5,
|
||||
negative_prompt=None,
|
||||
num_images_per_prompt=1,
|
||||
generator=None,
|
||||
latents=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
output_type="pil",
|
||||
):
|
||||
self._guidance_scale = guidance_scale
|
||||
# 0. Default height and width to unet
|
||||
height = self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = self.unet.config.sample_size * self.vae_scale_factor
|
||||
# to deal with lora scaling and other possible forward hook
|
||||
callback_steps = None
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
# 2. Define call parameter
|
||||
device = self._execution_device
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, _ = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
latents = inverted_latent
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=uncond_embeddings[i])["sample"]
|
||||
noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)["sample"]
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
progress_bar.update()
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
]
|
||||
else:
|
||||
image = latents
|
||||
image = self.image_processor.postprocess(
|
||||
image, output_type=output_type, do_denormalize=[True] * image.shape[0]
|
||||
)
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=False)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -61,6 +61,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import resolve_interpolation_mode
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -71,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -165,6 +166,7 @@ class SDText2ImageDataset:
|
||||
global_batch_size: int,
|
||||
num_workers: int,
|
||||
resolution: int = 512,
|
||||
interpolation_type: str = "bilinear",
|
||||
shuffle_buffer_size: int = 1000,
|
||||
pin_memory: bool = False,
|
||||
persistent_workers: bool = False,
|
||||
@@ -174,10 +176,12 @@ class SDText2ImageDataset:
|
||||
# flatten list using itertools
|
||||
train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
|
||||
|
||||
interpolation_mode = resolve_interpolation_mode(interpolation_type)
|
||||
|
||||
def transform(example):
|
||||
# resize image
|
||||
image = example["image"]
|
||||
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
image = TF.resize(image, resolution, interpolation=interpolation_mode)
|
||||
|
||||
# get crop coordinates and crop image
|
||||
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
|
||||
@@ -353,8 +357,9 @@ def append_dims(x, target_dims):
|
||||
|
||||
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
|
||||
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
|
||||
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
|
||||
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
|
||||
scaled_timestep = timestep_scaling * timestep
|
||||
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
|
||||
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
|
||||
@@ -572,6 +577,15 @@ def parse_args():
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation_type",
|
||||
type=str,
|
||||
default="bilinear",
|
||||
help=(
|
||||
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
|
||||
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
@@ -710,6 +724,50 @@ def parse_args():
|
||||
default=64,
|
||||
help="The rank of the LoRA projection matrix.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=int,
|
||||
default=64,
|
||||
help=(
|
||||
"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
|
||||
" update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_dropout",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_target_modules",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
|
||||
" be used. By default, LoRA will be applied to all conv and linear layers."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_encode_batch_size",
|
||||
type=int,
|
||||
default=32,
|
||||
required=False,
|
||||
help=(
|
||||
"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
|
||||
" Encoding or decoding the whole batch at once may run into OOM issues."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestep_scaling_factor",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help=(
|
||||
"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
|
||||
" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
|
||||
" suffice."
|
||||
),
|
||||
)
|
||||
# ----Mixed Precision----
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
@@ -915,9 +973,10 @@ def main(args):
|
||||
)
|
||||
|
||||
# 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
|
||||
lora_config = LoraConfig(
|
||||
r=args.lora_rank,
|
||||
target_modules=[
|
||||
if args.lora_target_modules is not None:
|
||||
lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
|
||||
else:
|
||||
lora_target_modules = [
|
||||
"to_q",
|
||||
"to_k",
|
||||
"to_v",
|
||||
@@ -932,7 +991,12 @@ def main(args):
|
||||
"downsamplers.0.conv",
|
||||
"upsamplers.0.conv",
|
||||
"time_emb_proj",
|
||||
],
|
||||
]
|
||||
lora_config = LoraConfig(
|
||||
r=args.lora_rank,
|
||||
target_modules=lora_target_modules,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
)
|
||||
unet = get_peft_model(unet, lora_config)
|
||||
|
||||
@@ -1051,6 +1115,7 @@ def main(args):
|
||||
global_batch_size=args.train_batch_size * accelerator.num_processes,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
resolution=args.resolution,
|
||||
interpolation_type=args.interpolation_type,
|
||||
shuffle_buffer_size=1000,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
@@ -1162,10 +1227,10 @@ def main(args):
|
||||
if vae.dtype != weight_dtype:
|
||||
vae.to(dtype=weight_dtype)
|
||||
|
||||
# encode pixel values with batch size of at most 32
|
||||
# encode pixel values with batch size of at most args.vae_encode_batch_size
|
||||
latents = []
|
||||
for i in range(0, pixel_values.shape[0], 32):
|
||||
latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
|
||||
for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
|
||||
latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
|
||||
latents = torch.cat(latents, dim=0)
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
@@ -1181,9 +1246,13 @@ def main(args):
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(
|
||||
start_timesteps, timestep_scaling=args.timestep_scaling_factor
|
||||
)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = scalings_for_boundary_conditions(
|
||||
timesteps, timestep_scaling=args.timestep_scaling_factor
|
||||
)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
|
||||
@@ -51,7 +51,8 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -59,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -113,7 +114,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
|
||||
if unet is None:
|
||||
raise ValueError("Must provide a `unet` when doing intermediate validation.")
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
state_dict = get_peft_model_state_dict(unet)
|
||||
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
to_load = state_dict
|
||||
else:
|
||||
to_load = args.output_dir
|
||||
@@ -193,8 +194,9 @@ def append_dims(x, target_dims):
|
||||
|
||||
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
|
||||
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
|
||||
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
|
||||
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
|
||||
scaled_timestep = timestep_scaling * timestep
|
||||
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
|
||||
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
|
||||
@@ -396,6 +398,15 @@ def parse_args():
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation_type",
|
||||
type=str,
|
||||
default="bilinear",
|
||||
help=(
|
||||
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
|
||||
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
@@ -534,6 +545,50 @@ def parse_args():
|
||||
default=64,
|
||||
help="The rank of the LoRA projection matrix.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=int,
|
||||
default=64,
|
||||
help=(
|
||||
"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
|
||||
" update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_dropout",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_target_modules",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
|
||||
" be used. By default, LoRA will be applied to all conv and linear layers."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_encode_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
required=False,
|
||||
help=(
|
||||
"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
|
||||
" Encoding or decoding the whole batch at once may run into OOM issues."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestep_scaling_factor",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help=(
|
||||
"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
|
||||
" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
|
||||
" suffice."
|
||||
),
|
||||
)
|
||||
# ----Mixed Precision----
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
@@ -776,10 +831,10 @@ def main(args):
|
||||
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
|
||||
lora_config = LoraConfig(
|
||||
r=args.lora_rank,
|
||||
lora_alpha=args.lora_rank,
|
||||
target_modules=[
|
||||
if args.lora_target_modules is not None:
|
||||
lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
|
||||
else:
|
||||
lora_target_modules = [
|
||||
"to_q",
|
||||
"to_k",
|
||||
"to_v",
|
||||
@@ -794,16 +849,19 @@ def main(args):
|
||||
"downsamplers.0.conv",
|
||||
"upsamplers.0.conv",
|
||||
"time_emb_proj",
|
||||
],
|
||||
]
|
||||
lora_config = LoraConfig(
|
||||
r=args.lora_rank,
|
||||
target_modules=lora_target_modules,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
)
|
||||
unet.add_adapter(lora_config)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
for param in unet.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(unet, dtype=torch.float32)
|
||||
|
||||
# Also move the alpha and sigma noise schedules to accelerator.device.
|
||||
alpha_schedule = alpha_schedule.to(accelerator.device)
|
||||
@@ -819,7 +877,7 @@ def main(args):
|
||||
unet_ = accelerator.unwrap_model(unet)
|
||||
# also save the checkpoints in native `diffusers` format so that it can be easily
|
||||
# be independently loaded via `load_lora_weights()`.
|
||||
state_dict = get_peft_model_state_dict(unet_)
|
||||
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
|
||||
StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)
|
||||
|
||||
for _, model in enumerate(models):
|
||||
@@ -929,7 +987,8 @@ def main(args):
|
||||
)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
interpolation_mode = resolve_interpolation_mode(args.interpolation_type)
|
||||
train_resize = transforms.Resize(args.resolution, interpolation=interpolation_mode)
|
||||
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
|
||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
|
||||
@@ -1121,11 +1180,11 @@ def main(args):
|
||||
|
||||
encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
|
||||
|
||||
# encode pixel values with batch size of at most 8
|
||||
# encode pixel values with batch size of at most args.vae_encode_batch_size
|
||||
pixel_values = pixel_values.to(dtype=vae.dtype)
|
||||
latents = []
|
||||
for i in range(0, pixel_values.shape[0], args.encode_batch_size):
|
||||
latents.append(vae.encode(pixel_values[i : i + args.encode_batch_size]).latent_dist.sample())
|
||||
for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
|
||||
latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
|
||||
latents = torch.cat(latents, dim=0)
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
@@ -1142,9 +1201,13 @@ def main(args):
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(
|
||||
start_timesteps, timestep_scaling=args.timestep_scaling_factor
|
||||
)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = scalings_for_boundary_conditions(
|
||||
timesteps, timestep_scaling=args.timestep_scaling_factor
|
||||
)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
@@ -1184,7 +1247,7 @@ def main(args):
|
||||
# solver timestep.
|
||||
|
||||
# With the adapters disabled, the `unet` is the regular teacher model.
|
||||
unet.disable_adapters()
|
||||
accelerator.unwrap_model(unet).disable_adapters()
|
||||
with torch.no_grad():
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = unet(
|
||||
@@ -1248,7 +1311,7 @@ def main(args):
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)
|
||||
|
||||
# re-enable unet adapters to turn the `unet` into a student unet.
|
||||
unet.enable_adapters()
|
||||
accelerator.unwrap_model(unet).enable_adapters()
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
@@ -1332,7 +1395,7 @@ def main(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -62,6 +62,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import resolve_interpolation_mode
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -77,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -171,6 +172,7 @@ class SDXLText2ImageDataset:
|
||||
global_batch_size: int,
|
||||
num_workers: int,
|
||||
resolution: int = 1024,
|
||||
interpolation_type: str = "bilinear",
|
||||
shuffle_buffer_size: int = 1000,
|
||||
pin_memory: bool = False,
|
||||
persistent_workers: bool = False,
|
||||
@@ -187,10 +189,12 @@ class SDXLText2ImageDataset:
|
||||
else:
|
||||
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
|
||||
|
||||
interpolation_mode = resolve_interpolation_mode(interpolation_type)
|
||||
|
||||
def transform(example):
|
||||
# resize image
|
||||
image = example["image"]
|
||||
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
image = TF.resize(image, resolution, interpolation=interpolation_mode)
|
||||
|
||||
# get crop coordinates and crop image
|
||||
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
|
||||
@@ -340,8 +344,9 @@ def append_dims(x, target_dims):
|
||||
|
||||
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
|
||||
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
|
||||
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
|
||||
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
|
||||
scaled_timestep = timestep_scaling * timestep
|
||||
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
|
||||
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
|
||||
@@ -546,6 +551,15 @@ def parse_args():
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation_type",
|
||||
type=str,
|
||||
default="bilinear",
|
||||
help=(
|
||||
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
|
||||
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_fix_crop_and_size",
|
||||
action="store_true",
|
||||
@@ -690,6 +704,50 @@ def parse_args():
|
||||
default=64,
|
||||
help="The rank of the LoRA projection matrix.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=int,
|
||||
default=64,
|
||||
help=(
|
||||
"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
|
||||
" update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_dropout",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_target_modules",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
|
||||
" be used. By default, LoRA will be applied to all conv and linear layers."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_encode_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
required=False,
|
||||
help=(
|
||||
"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
|
||||
" Encoding or decoding the whole batch at once may run into OOM issues."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestep_scaling_factor",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help=(
|
||||
"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
|
||||
" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
|
||||
" suffice."
|
||||
),
|
||||
)
|
||||
# ----Mixed Precision----
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
@@ -929,9 +987,10 @@ def main(args):
|
||||
)
|
||||
|
||||
# 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
|
||||
lora_config = LoraConfig(
|
||||
r=args.lora_rank,
|
||||
target_modules=[
|
||||
if args.lora_target_modules is not None:
|
||||
lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
|
||||
else:
|
||||
lora_target_modules = [
|
||||
"to_q",
|
||||
"to_k",
|
||||
"to_v",
|
||||
@@ -946,7 +1005,12 @@ def main(args):
|
||||
"downsamplers.0.conv",
|
||||
"upsamplers.0.conv",
|
||||
"time_emb_proj",
|
||||
],
|
||||
]
|
||||
lora_config = LoraConfig(
|
||||
r=args.lora_rank,
|
||||
target_modules=lora_target_modules,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
)
|
||||
unet = get_peft_model(unet, lora_config)
|
||||
|
||||
@@ -1090,6 +1154,7 @@ def main(args):
|
||||
global_batch_size=args.train_batch_size * accelerator.num_processes,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
resolution=args.resolution,
|
||||
interpolation_type=args.interpolation_type,
|
||||
shuffle_buffer_size=1000,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
@@ -1214,10 +1279,10 @@ def main(args):
|
||||
else:
|
||||
pixel_values = image
|
||||
|
||||
# encode pixel values with batch size of at most 8
|
||||
# encode pixel values with batch size of at most args.vae_encode_batch_size
|
||||
latents = []
|
||||
for i in range(0, pixel_values.shape[0], 8):
|
||||
latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
|
||||
for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
|
||||
latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
|
||||
latents = torch.cat(latents, dim=0)
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
@@ -1234,9 +1299,13 @@ def main(args):
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(
|
||||
start_timesteps, timestep_scaling=args.timestep_scaling_factor
|
||||
)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = scalings_for_boundary_conditions(
|
||||
timesteps, timestep_scaling=args.timestep_scaling_factor
|
||||
)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
|
||||
@@ -60,6 +60,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import resolve_interpolation_mode
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -70,7 +71,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -147,6 +148,7 @@ class SDText2ImageDataset:
|
||||
global_batch_size: int,
|
||||
num_workers: int,
|
||||
resolution: int = 512,
|
||||
interpolation_type: str = "bilinear",
|
||||
shuffle_buffer_size: int = 1000,
|
||||
pin_memory: bool = False,
|
||||
persistent_workers: bool = False,
|
||||
@@ -156,10 +158,12 @@ class SDText2ImageDataset:
|
||||
# flatten list using itertools
|
||||
train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
|
||||
|
||||
interpolation_mode = resolve_interpolation_mode(interpolation_type)
|
||||
|
||||
def transform(example):
|
||||
# resize image
|
||||
image = example["image"]
|
||||
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
image = TF.resize(image, resolution, interpolation=interpolation_mode)
|
||||
|
||||
# get crop coordinates and crop image
|
||||
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
|
||||
@@ -330,8 +334,9 @@ def append_dims(x, target_dims):
|
||||
|
||||
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
|
||||
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
|
||||
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
|
||||
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
|
||||
scaled_timestep = timestep_scaling * timestep
|
||||
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
|
||||
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
|
||||
@@ -549,6 +554,15 @@ def parse_args():
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation_type",
|
||||
type=str,
|
||||
default="bilinear",
|
||||
help=(
|
||||
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
|
||||
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
@@ -690,6 +704,26 @@ def parse_args():
|
||||
" does not have `time_cond_proj_dim` set."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_encode_batch_size",
|
||||
type=int,
|
||||
default=32,
|
||||
required=False,
|
||||
help=(
|
||||
"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
|
||||
" Encoding or decoding the whole batch at once may run into OOM issues."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestep_scaling_factor",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help=(
|
||||
"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
|
||||
" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
|
||||
" suffice."
|
||||
),
|
||||
)
|
||||
# ----Exponential Moving Average (EMA)----
|
||||
parser.add_argument(
|
||||
"--ema_decay",
|
||||
@@ -887,10 +921,12 @@ def main(args):
|
||||
|
||||
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
|
||||
if teacher_unet.config.time_cond_proj_dim is None:
|
||||
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
|
||||
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
|
||||
unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
time_cond_proj_dim = (
|
||||
teacher_unet.config.time_cond_proj_dim
|
||||
if teacher_unet.config.time_cond_proj_dim is not None
|
||||
else args.unet_time_cond_proj_dim
|
||||
)
|
||||
unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
|
||||
# load teacher_unet weights into unet
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
@@ -1034,6 +1070,7 @@ def main(args):
|
||||
global_batch_size=args.train_batch_size * accelerator.num_processes,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
resolution=args.resolution,
|
||||
interpolation_type=args.interpolation_type,
|
||||
shuffle_buffer_size=1000,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
@@ -1145,10 +1182,10 @@ def main(args):
|
||||
if vae.dtype != weight_dtype:
|
||||
vae.to(dtype=weight_dtype)
|
||||
|
||||
# encode pixel values with batch size of at most 32
|
||||
# encode pixel values with batch size of at most args.vae_encode_batch_size
|
||||
latents = []
|
||||
for i in range(0, pixel_values.shape[0], 32):
|
||||
latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
|
||||
for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
|
||||
latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
|
||||
latents = torch.cat(latents, dim=0)
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
@@ -1164,9 +1201,13 @@ def main(args):
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(
|
||||
start_timesteps, timestep_scaling=args.timestep_scaling_factor
|
||||
)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = scalings_for_boundary_conditions(
|
||||
timesteps, timestep_scaling=args.timestep_scaling_factor
|
||||
)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
|
||||
@@ -61,6 +61,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import resolve_interpolation_mode
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -76,7 +77,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -153,6 +154,7 @@ class SDXLText2ImageDataset:
|
||||
global_batch_size: int,
|
||||
num_workers: int,
|
||||
resolution: int = 1024,
|
||||
interpolation_type: str = "bilinear",
|
||||
shuffle_buffer_size: int = 1000,
|
||||
pin_memory: bool = False,
|
||||
persistent_workers: bool = False,
|
||||
@@ -169,10 +171,12 @@ class SDXLText2ImageDataset:
|
||||
else:
|
||||
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
|
||||
|
||||
interpolation_mode = resolve_interpolation_mode(interpolation_type)
|
||||
|
||||
def transform(example):
|
||||
# resize image
|
||||
image = example["image"]
|
||||
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
image = TF.resize(image, resolution, interpolation=interpolation_mode)
|
||||
|
||||
# get crop coordinates and crop image
|
||||
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
|
||||
@@ -318,8 +322,9 @@ def append_dims(x, target_dims):
|
||||
|
||||
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
|
||||
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
|
||||
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
|
||||
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
|
||||
scaled_timestep = timestep_scaling * timestep
|
||||
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
|
||||
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
|
||||
@@ -568,6 +573,15 @@ def parse_args():
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation_type",
|
||||
type=str,
|
||||
default="bilinear",
|
||||
help=(
|
||||
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
|
||||
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_fix_crop_and_size",
|
||||
action="store_true",
|
||||
@@ -715,6 +729,26 @@ def parse_args():
|
||||
" does not have `time_cond_proj_dim` set."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_encode_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
required=False,
|
||||
help=(
|
||||
"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
|
||||
" Encoding or decoding the whole batch at once may run into OOM issues."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestep_scaling_factor",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help=(
|
||||
"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
|
||||
" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
|
||||
" suffice."
|
||||
),
|
||||
)
|
||||
# ----Exponential Moving Average (EMA)----
|
||||
parser.add_argument(
|
||||
"--ema_decay",
|
||||
@@ -946,10 +980,12 @@ def main(args):
|
||||
|
||||
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
|
||||
if teacher_unet.config.time_cond_proj_dim is None:
|
||||
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
|
||||
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
|
||||
unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
time_cond_proj_dim = (
|
||||
teacher_unet.config.time_cond_proj_dim
|
||||
if teacher_unet.config.time_cond_proj_dim is not None
|
||||
else args.unet_time_cond_proj_dim
|
||||
)
|
||||
unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
|
||||
# load teacher_unet weights into unet
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
@@ -1118,6 +1154,7 @@ def main(args):
|
||||
global_batch_size=args.train_batch_size * accelerator.num_processes,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
resolution=args.resolution,
|
||||
interpolation_type=args.interpolation_type,
|
||||
shuffle_buffer_size=1000,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
@@ -1242,10 +1279,10 @@ def main(args):
|
||||
else:
|
||||
pixel_values = image
|
||||
|
||||
# encode pixel values with batch size of at most 8
|
||||
# encode pixel values with batch size of at most args.vae_encode_batch_size
|
||||
latents = []
|
||||
for i in range(0, pixel_values.shape[0], 8):
|
||||
latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
|
||||
for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
|
||||
latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
|
||||
latents = torch.cat(latents, dim=0)
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
@@ -1262,9 +1299,13 @@ def main(args):
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(
|
||||
start_timesteps, timestep_scaling=args.timestep_scaling_factor
|
||||
)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = scalings_for_boundary_conditions(
|
||||
timesteps, timestep_scaling=args.timestep_scaling_factor
|
||||
)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
|
||||
@@ -50,13 +50,14 @@ from diffusers import (
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -787,6 +788,12 @@ def main(args):
|
||||
logger.info("Initializing controlnet weights from unet")
|
||||
controlnet = ControlNetModel.from_unet(unet)
|
||||
|
||||
# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -846,9 +853,9 @@ def main(args):
|
||||
" doing mixed precision training, copy of the weights should still be float32."
|
||||
)
|
||||
|
||||
if accelerator.unwrap_model(controlnet).dtype != torch.float32:
|
||||
if unwrap_model(controlnet).dtype != torch.float32:
|
||||
raise ValueError(
|
||||
f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
|
||||
f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
@@ -1015,7 +1022,7 @@ def main(args):
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
|
||||
|
||||
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
|
||||
|
||||
@@ -1036,7 +1043,8 @@ def main(args):
|
||||
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
|
||||
],
|
||||
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
|
||||
).sample
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
@@ -1109,7 +1117,7 @@ def main(args):
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
controlnet = accelerator.unwrap_model(controlnet)
|
||||
controlnet = unwrap_model(controlnet)
|
||||
controlnet.save_pretrained(args.output_dir)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -52,13 +52,14 @@ from diffusers import (
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -847,6 +848,11 @@ def main(args):
|
||||
logger.info("Initializing controlnet weights from unet")
|
||||
controlnet = ControlNetModel.from_unet(unet)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -908,9 +914,9 @@ def main(args):
|
||||
" doing mixed precision training, copy of the weights should still be float32."
|
||||
)
|
||||
|
||||
if accelerator.unwrap_model(controlnet).dtype != torch.float32:
|
||||
if unwrap_model(controlnet).dtype != torch.float32:
|
||||
raise ValueError(
|
||||
f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
|
||||
f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
@@ -1158,7 +1164,8 @@ def main(args):
|
||||
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
|
||||
],
|
||||
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
|
||||
).sample
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
@@ -1223,7 +1230,7 @@ def main(args):
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
controlnet = accelerator.unwrap_model(controlnet)
|
||||
controlnet = unwrap_model(controlnet)
|
||||
controlnet.save_pretrained(args.output_dir)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ accelerate launch train_dreambooth_lora_sdxl.py \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--learning_rate=1e-5 \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
|
||||
@@ -55,13 +55,14 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -129,15 +130,12 @@ def log_validation(
|
||||
if vae is not None:
|
||||
pipeline_args["vae"] = vae
|
||||
|
||||
if text_encoder is not None:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
# create pipeline (note: unet and vae are loaded again in float32)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
unet=unet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
@@ -794,6 +792,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
@@ -931,11 +930,16 @@ def main(args):
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
for model in models:
|
||||
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
|
||||
sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
@@ -946,7 +950,7 @@ def main(args):
|
||||
# pop models so that they are not loaded again
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
if isinstance(model, type(unwrap_model(text_encoder))):
|
||||
# load transformers style into model
|
||||
load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
|
||||
model.config = load_model.config
|
||||
@@ -991,15 +995,12 @@ def main(args):
|
||||
" doing mixed precision training. copy of the weights should still be float32."
|
||||
)
|
||||
|
||||
if accelerator.unwrap_model(unet).dtype != torch.float32:
|
||||
raise ValueError(
|
||||
f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
|
||||
)
|
||||
if unwrap_model(unet).dtype != torch.float32:
|
||||
raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")
|
||||
|
||||
if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
|
||||
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
|
||||
raise ValueError(
|
||||
f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
|
||||
f" {low_precision_error_string}"
|
||||
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
@@ -1246,7 +1247,7 @@ def main(args):
|
||||
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
|
||||
)
|
||||
|
||||
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
|
||||
if unwrap_model(unet).config.in_channels == channels * 2:
|
||||
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
|
||||
|
||||
if args.class_labels_conditioning == "timesteps":
|
||||
@@ -1256,8 +1257,8 @@ def main(args):
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(
|
||||
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
|
||||
).sample
|
||||
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
|
||||
)[0]
|
||||
|
||||
if model_pred.shape[1] == 6:
|
||||
model_pred, _ = torch.chunk(model_pred, 2, dim=1)
|
||||
@@ -1350,9 +1351,9 @@ def main(args):
|
||||
|
||||
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
||||
images = log_validation(
|
||||
text_encoder,
|
||||
unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
args,
|
||||
accelerator,
|
||||
@@ -1375,14 +1376,14 @@ def main(args):
|
||||
pipeline_args = {}
|
||||
|
||||
if text_encoder is not None:
|
||||
pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
|
||||
pipeline_args["text_encoder"] = unwrap_model(text_encoder)
|
||||
|
||||
if args.skip_save_text_encoder:
|
||||
pipeline_args["text_encoder"] = None
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
unet=unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
**pipeline_args,
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -35,7 +35,7 @@ from huggingface_hub import create_repo, upload_folder
|
||||
from huggingface_hub.utils import insecure_hashlib
|
||||
from packaging import version
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torch.utils.data import Dataset
|
||||
@@ -54,12 +54,19 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -647,6 +654,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
@@ -843,6 +851,11 @@ def main(args):
|
||||
)
|
||||
text_encoder.add_adapter(text_lora_config)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
@@ -852,10 +865,12 @@ def main(args):
|
||||
text_encoder_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
elif isinstance(model, type(unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -875,18 +890,41 @@ def main(args):
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
unet_ = model
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
elif isinstance(model, type(unwrap_model(text_encoder))):
|
||||
text_encoder_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
|
||||
)
|
||||
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
if args.train_text_encoder:
|
||||
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
|
||||
|
||||
# Make sure the trainable params are in float32. This is again needed since the base models
|
||||
# are in `weight_dtype`. More details:
|
||||
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet_]
|
||||
if args.train_text_encoder:
|
||||
models.append(text_encoder_)
|
||||
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
@@ -901,6 +939,15 @@ def main(args):
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.append(text_encoder)
|
||||
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
@@ -1116,7 +1163,7 @@ def main(args):
|
||||
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
|
||||
)
|
||||
|
||||
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
|
||||
if unwrap_model(unet).config.in_channels == channels * 2:
|
||||
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
|
||||
|
||||
if args.class_labels_conditioning == "timesteps":
|
||||
@@ -1126,8 +1173,12 @@ def main(args):
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(
|
||||
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
|
||||
).sample
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
class_labels=class_labels,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# if model predicts variance, throw away the prediction. we will only train on the
|
||||
# simplified training objective. This means that all schedulers using the fine tuned
|
||||
@@ -1213,8 +1264,8 @@ def main(args):
|
||||
# create pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
|
||||
unet=unwrap_model(unet),
|
||||
text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
@@ -1282,14 +1333,14 @@ def main(args):
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
|
||||
else:
|
||||
text_encoder_state_dict = None
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from huggingface_hub.utils import insecure_hashlib
|
||||
from packaging import version
|
||||
from peft import LoraConfig
|
||||
from peft import LoraConfig, set_peft_model_state_dict
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
@@ -53,13 +53,19 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -780,13 +786,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
|
||||
text_input_ids = text_input_ids_list[i]
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(text_encoder.device),
|
||||
output_hidden_states=True,
|
||||
text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
|
||||
)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
prompt_embeds = prompt_embeds[-1][-2]
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
@@ -997,16 +1002,10 @@ def main(args):
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
for model in models:
|
||||
for param in model.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
@@ -1018,13 +1017,13 @@ def main(args):
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
@@ -1049,27 +1048,46 @@ def main(args):
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
unet_ = model
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_ = model
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
)
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
# Do we need to call `scale_lora_layers()` here?
|
||||
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
|
||||
|
||||
_set_state_dict_into_text_encoder(
|
||||
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32. This is again needed since the base models
|
||||
# are in `weight_dtype`. More details:
|
||||
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet_]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one_, text_encoder_two_])
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
@@ -1084,6 +1102,15 @@ def main(args):
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
|
||||
if args.train_text_encoder:
|
||||
@@ -1429,7 +1456,8 @@ def main(args):
|
||||
timesteps,
|
||||
prompt_embeds_input,
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
).sample
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
@@ -1443,8 +1471,12 @@ def main(args):
|
||||
)
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
model_pred = unet(
|
||||
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
|
||||
).sample
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds_input,
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
@@ -1496,6 +1528,7 @@ def main(args):
|
||||
else unet_lora_parameters
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
@@ -1617,16 +1650,16 @@ def main(args):
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_one = unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
)
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_two = unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
)
|
||||
|
||||
@@ -49,10 +49,11 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -489,6 +490,11 @@ def main():
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -845,7 +851,7 @@ def main():
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# Predict the noise residual and compute loss
|
||||
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
@@ -919,9 +925,9 @@ def main():
|
||||
# The models need unwrapping because for compatibility in distributed training mode.
|
||||
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=accelerator.unwrap_model(vae),
|
||||
unet=unwrap_model(unet),
|
||||
text_encoder=unwrap_model(text_encoder),
|
||||
vae=unwrap_model(vae),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
@@ -965,14 +971,14 @@ def main():
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unwrap_model(unet)
|
||||
if args.use_ema:
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
|
||||
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=accelerator.unwrap_model(vae),
|
||||
text_encoder=unwrap_model(text_encoder),
|
||||
vae=unwrap_model(vae),
|
||||
unet=unet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
|
||||
@@ -52,10 +52,11 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instru
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -531,6 +532,11 @@ def main():
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -1044,8 +1050,12 @@ def main():
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
model_pred = unet(
|
||||
concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
||||
).sample
|
||||
concatenated_noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
@@ -1099,7 +1109,7 @@ def main():
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
### BEGIN: Perform validation every `validation_epochs` steps
|
||||
if global_step % args.validation_steps == 0 or global_step == 1:
|
||||
if global_step % args.validation_steps == 0:
|
||||
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
@@ -1115,7 +1125,7 @@ def main():
|
||||
# The models need unwrapping because for compatibility in distributed training mode.
|
||||
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
unet=unwrap_model(unet),
|
||||
text_encoder=text_encoder_1,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer_1,
|
||||
@@ -1177,7 +1187,7 @@ def main():
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unwrap_model(unet)
|
||||
if args.use_ema:
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -494,9 +494,7 @@ class ControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
return self.control_model.attn_processors
|
||||
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -509,7 +507,7 @@ class ControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
self.control_model.set_attn_processor(processor, _remove_lora)
|
||||
self.control_model.set_attn_processor(processor)
|
||||
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
# Diffusion Model Alignment Using Direct Preference Optimization
|
||||
|
||||
This directory provides LoRA implementations of Diffusion DPO proposed in [DiffusionModel Alignment Using Direct Preference Optimization](https://arxiv.org/abs/2311.12908) by Bram Wallace, Meihua Dang, Rafael Rafailov, Linqi Zhou, Aaron Lou, Senthil Purushwalkam, Stefano Ermon, Caiming Xiong, Shafiq Joty, and Nikhil Naik.
|
||||
|
||||
We provide implementations for both Stable Diffusion (SD) and Stable Diffusion XL (SDXL). The original checkpoints are available at the URLs below:
|
||||
|
||||
* [mhdang/dpo-sd1.5-text2image-v1](https://huggingface.co/mhdang/dpo-sd1.5-text2image-v1)
|
||||
* [mhdang/dpo-sdxl-text2image-v1](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1)
|
||||
|
||||
> 💡 Note: The scripts are highly experimental and were only tested on low-data regimes. Proceed with caution. Feel free to let us know about your findings via GitHub issues.
|
||||
|
||||
## SD training command
|
||||
|
||||
```bash
|
||||
accelerate launch train_diffusion_dpo.py \
|
||||
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
|
||||
--output_dir="diffusion-dpo" \
|
||||
--mixed_precision="fp16" \
|
||||
--dataset_name=kashif/pickascore \
|
||||
--resolution=512 \
|
||||
--train_batch_size=16 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--rank=8 \
|
||||
--learning_rate=1e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=10000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--run_validation --validation_steps=200 \
|
||||
--seed="0" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## SDXL training command
|
||||
|
||||
```bash
|
||||
accelerate launch train_diffusion_dpo_sdxl.py \
|
||||
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--output_dir="diffusion-sdxl-dpo" \
|
||||
--mixed_precision="fp16" \
|
||||
--dataset_name=kashif/pickascore \
|
||||
--train_batch_size=8 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--rank=8 \
|
||||
--learning_rate=1e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=2000 \
|
||||
--checkpointing_steps=500 \
|
||||
--run_validation --validation_steps=50 \
|
||||
--seed="0" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
This is based on the amazing work done by [Bram](https://github.com/bram-w) here for Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/.
|
||||
@@ -0,0 +1,8 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft
|
||||
wandb
|
||||
@@ -0,0 +1,959 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 bram-w, The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
import wandb
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
VALIDATION_PROMPTS = [
|
||||
"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
|
||||
"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
||||
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
|
||||
]
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=revision,
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
if model_class == "CLIPTextModel":
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
return CLIPTextModel
|
||||
else:
|
||||
raise ValueError(f"{model_class} is not supported.")
|
||||
|
||||
|
||||
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
|
||||
|
||||
# create pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
if not is_final_validation:
|
||||
pipeline.unet = accelerator.unwrap_model(unet)
|
||||
else:
|
||||
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
images = []
|
||||
context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
|
||||
|
||||
for prompt in VALIDATION_PROMPTS:
|
||||
with context:
|
||||
image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
images.append(image)
|
||||
|
||||
tracker_key = "test" if is_final_validation else "validation"
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
tracker_key: [
|
||||
wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Also log images without the LoRA params for comparison.
|
||||
if is_final_validation:
|
||||
pipeline.disable_lora()
|
||||
no_lora_images = [
|
||||
pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
|
||||
]
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in no_lora_images])
|
||||
tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"test_without_lora": [
|
||||
wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
|
||||
for i, image in enumerate(no_lora_images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
||||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
||||
" or to a folder containing files that 🤗 Datasets can understand."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_split_name",
|
||||
type=str,
|
||||
default="validation",
|
||||
help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_validation",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to run validation inference in between training and also after training. Helps to track progress.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_steps",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Run validation every X steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="diffusion-dpo-lora",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory where the downloaded models and datasets will be stored.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_encode_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size to use for VAE encoding of the images for efficient processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_hflip",
|
||||
action="store_true",
|
||||
help="whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_crop",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
||||
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=("Max number of checkpoints to store."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beta_dpo",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="DPO KL Divergence penalty.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_num_cycles",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
||||
)
|
||||
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prior_generation_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp32", "fp16", "bf16"],
|
||||
help=(
|
||||
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tracker_name",
|
||||
type=str,
|
||||
default="diffusion-dpo-lora",
|
||||
help=("The name of the tracker to report results to."),
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dataset_name is None:
|
||||
raise ValueError("Must provide a `dataset_name`.")
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def tokenize_captions(tokenizer, examples):
|
||||
max_length = tokenizer.model_max_length
|
||||
captions = []
|
||||
for caption in examples["caption"]:
|
||||
captions.append(caption)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
captions, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt"
|
||||
)
|
||||
|
||||
return text_inputs.input_ids
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_prompt(text_encoder, input_ids):
|
||||
text_input_ids = input_ids.to(text_encoder.device)
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def main(args):
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
# Load the tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
# import correct text encoder class
|
||||
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# Set up LoRA.
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
# Add adapter and make sure the trainable params are in float32.
|
||||
unet.add_adapter(unet_lora_config)
|
||||
if args.mixed_precision == "fp16":
|
||||
for param in unet.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
# there are only two options here. Either are just the unet attn processor layers
|
||||
# or there are the unet and text encoder atten layers
|
||||
unet_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
output_dir,
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=None,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
unet_ = None
|
||||
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
# Optimizer creation
|
||||
params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# Dataset and DataLoaders creation:
|
||||
train_dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
cache_dir=args.cache_dir,
|
||||
split=args.dataset_split_name,
|
||||
)
|
||||
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(int(args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution),
|
||||
transforms.Lambda(lambda x: x) if args.no_hflip else transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_train(examples):
|
||||
all_pixel_values = []
|
||||
for col_name in ["jpg_0", "jpg_1"]:
|
||||
images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]]
|
||||
pixel_values = [train_transforms(image) for image in images]
|
||||
all_pixel_values.append(pixel_values)
|
||||
|
||||
# Double on channel dim, jpg_y then jpg_w
|
||||
im_tup_iterator = zip(*all_pixel_values)
|
||||
combined_pixel_values = []
|
||||
for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]):
|
||||
if label_0 == 0:
|
||||
im_tup = im_tup[::-1]
|
||||
combined_im = torch.cat(im_tup, dim=0) # no batch dim
|
||||
combined_pixel_values.append(combined_im)
|
||||
examples["pixel_values"] = combined_pixel_values
|
||||
|
||||
examples["input_ids"] = tokenize_captions(tokenizer, examples)
|
||||
return examples
|
||||
|
||||
with accelerator.main_process_first():
|
||||
if args.max_train_samples is not None:
|
||||
train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
# Set the training transforms
|
||||
train_dataset = train_dataset.with_transform(preprocess_train)
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
final_dict = {"pixel_values": pixel_values}
|
||||
final_dict["input_ids"] = torch.stack([example["input_ids"] for example in examples])
|
||||
return final_dict
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers(args.tracker_name, config=vars(args))
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the mos recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
unet.train()
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
|
||||
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
|
||||
feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
|
||||
|
||||
latents = []
|
||||
for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
|
||||
latents.append(
|
||||
vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
|
||||
)
|
||||
latents = torch.cat(latents, dim=0)
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
bsz = latents.shape[0] // 2
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
|
||||
).repeat(2)
|
||||
|
||||
# Add noise to the model input according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = encode_prompt(text_encoder, batch["input_ids"]).repeat(2, 1, 1)
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# Compute losses.
|
||||
model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
|
||||
model_losses_w, model_losses_l = model_losses.chunk(2)
|
||||
|
||||
# For logging
|
||||
raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
|
||||
model_diff = model_losses_w - model_losses_l # These are both LBS (as is t)
|
||||
|
||||
# Reference model predictions.
|
||||
accelerator.unwrap_model(unet).disable_adapters()
|
||||
with torch.no_grad():
|
||||
ref_preds = unet(
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
).sample.detach()
|
||||
ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
|
||||
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
|
||||
|
||||
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
||||
ref_diff = ref_losses_w - ref_losses_l
|
||||
raw_ref_loss = ref_loss.mean()
|
||||
|
||||
# Re-enable adapters.
|
||||
accelerator.unwrap_model(unet).enable_adapters()
|
||||
|
||||
# Final loss.
|
||||
scale_term = -0.5 * args.beta_dpo
|
||||
inside_term = scale_term * (model_diff - ref_diff)
|
||||
loss = -1 * F.logsigmoid(inside_term.mean())
|
||||
|
||||
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
|
||||
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
checkpoints = os.listdir(args.output_dir)
|
||||
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
||||
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
||||
|
||||
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
||||
if len(checkpoints) >= args.checkpoints_total_limit:
|
||||
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
||||
shutil.rmtree(removing_checkpoint)
|
||||
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
if args.run_validation and global_step % args.validation_steps == 0:
|
||||
log_validation(
|
||||
args, unet=unet, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
|
||||
)
|
||||
|
||||
logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"raw_model_loss": raw_model_loss.detach().item(),
|
||||
"ref_loss": raw_ref_loss.detach().item(),
|
||||
"implicit_acc": implicit_acc.detach().item(),
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=args.output_dir, unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=None
|
||||
)
|
||||
|
||||
# Final validation?
|
||||
if args.run_validation:
|
||||
log_validation(
|
||||
args,
|
||||
unet=None,
|
||||
accelerator=accelerator,
|
||||
weight_dtype=weight_dtype,
|
||||
epoch=epoch,
|
||||
is_final_validation=True,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,76 @@
|
||||
# InstructPix2Pix text-to-edit-image fine-tuning
|
||||
This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
|
||||
This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script provides further support add LoRA layers for unet model.
|
||||
|
||||
## Training script example
|
||||
|
||||
```bash
|
||||
export MODEL_ID="timbrooks/instruct-pix2pix"
|
||||
export DATASET_ID="instruction-tuning-sd/cartoonization"
|
||||
export OUTPUT_DIR="instructPix2Pix-cartoonization"
|
||||
|
||||
accelerate launch finetune_instruct_pix2pix.py \
|
||||
--pretrained_model_name_or_path=$MODEL_ID \
|
||||
--dataset_name=$DATASET_ID \
|
||||
--enable_xformers_memory_efficient_attention \
|
||||
--resolution=256 --random_flip \
|
||||
--train_batch_size=2 --gradient_accumulation_steps=4 --gradient_checkpointing \
|
||||
--max_train_steps=15000 \
|
||||
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
|
||||
--learning_rate=5e-05 --lr_warmup_steps=0 \
|
||||
--val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
|
||||
--validation_prompt="Generate a cartoonized version of the natural image" \
|
||||
--seed=42 \
|
||||
--rank=4 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--report_to=wandb \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Inference
|
||||
After training the model and the lora weight of the model is stored in the ```$OUTPUT_DIR```.
|
||||
|
||||
```bash
|
||||
# load the base model pipeline
|
||||
pipe_lora = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix")
|
||||
|
||||
# Load LoRA weights from the provided path
|
||||
output_dir = "path/to/lora_weight_directory"
|
||||
pipe_lora.unet.load_attn_procs(output_dir)
|
||||
|
||||
input_image_path = "/path/to/input_image"
|
||||
input_image = Image.open(input_image_path)
|
||||
edited_images = pipe_lora(num_images_per_prompt=1, prompt=args.edit_prompt, image=input_image, num_inference_steps=1000).images
|
||||
edited_images[0].show()
|
||||
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
Here is an example of using the script to train a instructpix2pix model.
|
||||
Trained on google colab T4 GPU
|
||||
|
||||
```bash
|
||||
MODEL_ID="timbrooks/instruct-pix2pix"
|
||||
DATASET_ID="instruction-tuning-sd/cartoonization"
|
||||
TRAIN_EPOCHS=100
|
||||
```
|
||||
|
||||
Below are few examples for given the input image, edit_prompt and the edited_image (output of the model)
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/Aiden-Frost/Efficiently-teaching-counting-and-cartoonization-to-InstructPix2Pix.-/blob/main/diffusers_result_assets/edited_image_results.png?raw=true" alt="instructpix2pix-inputs" width=600/>
|
||||
</p>
|
||||
|
||||
|
||||
Here are some rough statistics about the training model using this script
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/Aiden-Frost/Efficiently-teaching-counting-and-cartoonization-to-InstructPix2Pix.-/blob/main/diffusers_result_assets/results.png?raw=true" alt="instructpix2pix-inputs" width=600/>
|
||||
</p>
|
||||
|
||||
## References
|
||||
|
||||
* InstructPix2Pix - https://github.com/timothybrooks/instruct-pix2pix
|
||||
* Dataset and example training script - https://huggingface.co/blog/instruction-tuning-sd
|
||||
* For more information about the project - https://github.com/Aiden-Frost/Efficiently-teaching-counting-and-cartoonization-to-InstructPix2Pix.-
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,93 @@
|
||||
# Multi Subject Dreambooth for Inpainting Models
|
||||
|
||||
Please note that this project is not actively maintained. However, you can open an issue and tag @gzguevara.
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requieres prompt-image-mask pairs. The Unet of inpainiting models have 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself).
|
||||
|
||||
**The first part**, the `multi_inpaint_dataset.ipynb` notebook, demonstrates how make a 🤗 dataset of prompt-image-mask pairs. You can, however, skip the first part and move straight to the second part with the example datasets in this project. ([cat toy dataset masked](https://huggingface.co/datasets/gzguevara/cat_toy_masked), [mr. potato head dataset masked](https://huggingface.co/datasets/gzguevara/mr_potato_head_masked))
|
||||
|
||||
**The second part**, the `train_multi_subject_inpainting.py` training script, demonstrates how to implement a training procedure for one or more subjects and adapt it for stable diffusion for inpainting.
|
||||
|
||||
## 1. Data Collection: Make Prompt-Image-Mask Pairs
|
||||
|
||||
Earlier training scripts have provided approaches like random masking for the training images. This project provides a notebook for more precise mask setting.
|
||||
|
||||
The notebook can be found here: [](https://colab.research.google.com/drive/1JNEASI_B7pLW1srxhgln6nM0HoGAQT32?usp=sharing)
|
||||
|
||||
The `multi_inpaint_dataset.ipynb` notebook, takes training & validation images, on which the user draws masks and provides prompts to make a prompt-image-mask pairs. This ensures that during training, the loss is computed on the area masking the object of interest, rather than on random areas. Moreover, the `multi_inpaint_dataset.ipynb` notebook allows you to build a validation dataset with corresponding masks for monitoring the training process. Example below:
|
||||
|
||||

|
||||
|
||||
You can build multiple datasets for every subject and upload them to the 🤗 hub. Later, when launching the training script you can indicate the paths of the datasets, on which you would like to finetune Stable Diffusion for inpaining.
|
||||
|
||||
## 2. Train Multi Subject Dreambooth for Inpainting
|
||||
|
||||
### 2.1. Setting The Training Configuration
|
||||
|
||||
Before launching the training script, make sure to select the inpainting the target model, the output directory and the 🤗 datasets.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
export DATASET_1="gzguevara/mr_potato_head_masked"
|
||||
export DATASET_2="gzguevara/cat_toy_masked"
|
||||
... # Further paths to 🤗 datasets
|
||||
```
|
||||
|
||||
### 2.2. Launching The Training Script
|
||||
|
||||
```bash
|
||||
accelerate launch train_multi_subject_dreambooth_inpaint.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir $DATASET_1 $DATASET_2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--learning_rate=3e-6 \
|
||||
--max_train_steps=500 \
|
||||
--report_to_wandb
|
||||
```
|
||||
|
||||
### 2.3. Fine-tune text encoder with the UNet.
|
||||
|
||||
The script also allows to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results especially on faces.
|
||||
Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
|
||||
|
||||
___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___
|
||||
|
||||
```bash
|
||||
accelerate launch train_multi_subject_dreambooth_inpaint.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir $DATASET_1 $DATASET_2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--learning_rate=2e-6 \
|
||||
--max_train_steps=500 \
|
||||
--report_to_wandb \
|
||||
--train_text_encoder
|
||||
```
|
||||
|
||||
## 3. Results
|
||||
|
||||
A [](https://wandb.ai/gzguevara/uncategorized/reports/Multi-Subject-Dreambooth-for-Inpainting--Vmlldzo2MzY5NDQ4?accessToken=y0nya2d7baguhbryxaikbfr1203amvn1jsmyl07vk122mrs7tnph037u1nqgse8t) is provided showing the training progress by every 50 steps. Note, the reported weights & baises run was performed on a A100 GPU with the following stetting:
|
||||
|
||||
```bash
|
||||
accelerate launch train_multi_subject_dreambooth_inpaint.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir $DATASET_1 $DATASET_2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--resolution=512 \
|
||||
--train_batch_size=10 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--learning_rate=1e-6 \
|
||||
--max_train_steps=500 \
|
||||
--report_to_wandb \
|
||||
--train_text_encoder
|
||||
```
|
||||
Here you can see the target objects on my desk and next to my plant:
|
||||
|
||||

|
||||
@@ -0,0 +1,8 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
datasets>=2.16.0
|
||||
wandb>=0.16.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
+661
@@ -0,0 +1,661 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionInpaintPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument("--instance_data_dir", nargs="+", help="Instance data directories")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder", default=False, action="store_true", help="Whether to train the text encoder"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-6,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
||||
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
|
||||
" using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpointing_from",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=("Start to checkpoint from step"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help=(
|
||||
"Run validation every X steps. Validation consists of running the prompt"
|
||||
" `args.validation_prompt` multiple times: `args.num_validation_images`"
|
||||
" and logging the images."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_from",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Start to validate from step"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
||||
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
||||
" for more docs"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_project_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The w&b name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to_wandb", default=False, action="store_true", help="Whether to report to weights and biases"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask):
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
mask = np.array(mask.convert("L"))
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None, None]
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
datasets_paths,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.datasets_paths = (datasets_paths,)
|
||||
self.datasets = [load_dataset(dataset_path) for dataset_path in self.datasets_paths[0]]
|
||||
self.train_data = concatenate_datasets([dataset["train"] for dataset in self.datasets])
|
||||
self.test_data = concatenate_datasets([dataset["test"] for dataset in self.datasets])
|
||||
|
||||
self.image_normalize = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
def set_image(self, img, switch):
|
||||
if img.mode not in ["RGB", "L"]:
|
||||
img = img.convert("RGB")
|
||||
|
||||
if switch:
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
|
||||
img = img.resize((512, 512), Image.BILINEAR)
|
||||
|
||||
return img
|
||||
|
||||
def __len__(self):
|
||||
return len(self.train_data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
# Lettings
|
||||
example = {}
|
||||
img_idx = index % len(self.train_data)
|
||||
switch = random.choice([True, False])
|
||||
|
||||
# Load image
|
||||
image = self.set_image(self.train_data[img_idx]["image"], switch)
|
||||
|
||||
# Normalize image
|
||||
image_norm = self.image_normalize(image)
|
||||
|
||||
# Tokenise prompt
|
||||
tokenized_prompt = self.tokenizer(
|
||||
self.train_data[img_idx]["prompt"],
|
||||
padding="do_not_pad",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids
|
||||
|
||||
# Load masks for image
|
||||
masks = [
|
||||
self.set_image(self.train_data[img_idx][key], switch) for key in self.train_data[img_idx] if "mask" in key
|
||||
]
|
||||
|
||||
# Build example
|
||||
example["PIL_image"] = image
|
||||
example["instance_image"] = image_norm
|
||||
example["instance_prompt_id"] = tokenized_prompt
|
||||
example["instance_masks"] = masks
|
||||
|
||||
return example
|
||||
|
||||
|
||||
def weighted_mask(masks):
|
||||
# Convert each mask to a NumPy array and ensure it's binary
|
||||
mask_arrays = [np.array(mask) / 255 for mask in masks] # Normalizing to 0-1 range
|
||||
|
||||
# Generate random weights and apply them to each mask
|
||||
weights = [random.random() for _ in masks]
|
||||
weights = [weight / sum(weights) for weight in weights]
|
||||
weighted_masks = [mask * weight for mask, weight in zip(mask_arrays, weights)]
|
||||
|
||||
# Sum the weighted masks
|
||||
summed_mask = np.sum(weighted_masks, axis=0)
|
||||
|
||||
# Apply a threshold to create the final mask
|
||||
threshold = 0.5 # This threshold can be adjusted
|
||||
result_mask = summed_mask >= threshold
|
||||
|
||||
# Convert the result back to a PIL image
|
||||
return Image.fromarray(result_mask.astype(np.uint8) * 255)
|
||||
|
||||
|
||||
def collate_fn(examples, tokenizer):
|
||||
input_ids = [example["instance_prompt_id"] for example in examples]
|
||||
pixel_values = [example["instance_image"] for example in examples]
|
||||
|
||||
masks, masked_images = [], []
|
||||
|
||||
for example in examples:
|
||||
# generate a random mask
|
||||
mask = weighted_mask(example["instance_masks"])
|
||||
|
||||
# prepare mask and masked image
|
||||
mask, masked_image = prepare_mask_and_masked_image(example["PIL_image"], mask)
|
||||
|
||||
masks.append(mask)
|
||||
masked_images.append(masked_image)
|
||||
|
||||
pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
|
||||
masks = torch.stack(masks)
|
||||
masked_images = torch.stack(masked_images)
|
||||
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
|
||||
|
||||
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def log_validation(pipeline, text_encoder, unet, val_pairs, accelerator):
|
||||
# update pipeline (note: unet and vae are loaded again in float32)
|
||||
pipeline.text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
pipeline.unet = accelerator.unwrap_model(unet)
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
val_results = [{"data_or_path": pipeline(**pair).images[0], "caption": pair["prompt"]} for pair in val_pairs]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
wandb.log({"validation": [wandb.Image(**val_result) for val_result in val_results]})
|
||||
|
||||
|
||||
def checkpoint(args, global_step, accelerator):
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
project_config = ProjectConfiguration(
|
||||
total_limit=args.checkpoints_total_limit,
|
||||
project_dir=args.output_dir,
|
||||
logging_dir=Path(args.output_dir, args.logging_dir),
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
project_config=project_config,
|
||||
log_with="wandb" if args.report_to_wandb else None,
|
||||
)
|
||||
|
||||
if args.report_to_wandb and not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
|
||||
# Load the tokenizer & models and create wrapper for stable diffusion
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder"
|
||||
).requires_grad_(args.train_text_encoder)
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae").requires_grad_(False)
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
params=itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
if args.train_text_encoder
|
||||
else unet.parameters(),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
||||
train_dataset = DreamBoothDataset(
|
||||
tokenizer=tokenizer,
|
||||
datasets_paths=args.instance_data_dir,
|
||||
)
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=lambda examples: collate_fn(examples, tokenizer),
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
)
|
||||
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
accelerator.register_for_checkpointing(lr_scheduler)
|
||||
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
else:
|
||||
weight_dtype = torch.float32
|
||||
|
||||
# Move text_encode and vae to gpu.
|
||||
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
||||
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
if not args.train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
# Afterwards we calculate our number of training epochs
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
tracker_config = vars(copy.deepcopy(args))
|
||||
accelerator.init_trackers(args.validation_project_name, config=tracker_config)
|
||||
|
||||
# create validation pipeline (note: unet and vae are loaded again in float32)
|
||||
val_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
unet=unet,
|
||||
vae=vae,
|
||||
torch_dtype=weight_dtype,
|
||||
safety_checker=None,
|
||||
)
|
||||
val_pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# prepare validation dataset
|
||||
val_pairs = [
|
||||
{
|
||||
"image": example["image"],
|
||||
"mask_image": mask,
|
||||
"prompt": example["prompt"],
|
||||
}
|
||||
for example in train_dataset.test_data
|
||||
for mask in [example[key] for key in example if "mask" in key]
|
||||
]
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
for model in models:
|
||||
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
|
||||
print()
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
for epoch in range(first_epoch, num_train_epochs):
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Convert masked images to latent space
|
||||
masked_latents = vae.encode(
|
||||
batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
|
||||
).latent_dist.sample()
|
||||
masked_latents = masked_latents * vae.config.scaling_factor
|
||||
|
||||
masks = batch["masks"]
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
mask = torch.stack(
|
||||
[
|
||||
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
|
||||
for mask in masks
|
||||
]
|
||||
)
|
||||
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# concatenate the noised latents with the mask and the masked latents
|
||||
latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
if args.train_text_encoder
|
||||
else unet.parameters()
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if (
|
||||
global_step % args.validation_steps == 0
|
||||
and global_step >= args.validation_from
|
||||
and args.report_to_wandb
|
||||
):
|
||||
log_validation(
|
||||
val_pipeline,
|
||||
text_encoder,
|
||||
unet,
|
||||
val_pairs,
|
||||
accelerator,
|
||||
)
|
||||
|
||||
if global_step % args.checkpointing_steps == 0 and global_step >= args.checkpointing_from:
|
||||
checkpoint(
|
||||
args,
|
||||
global_step,
|
||||
accelerator,
|
||||
)
|
||||
|
||||
# Step logging
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Terminate training
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -6,4 +6,4 @@ torch==2.0.1
|
||||
torchvision>=0.16
|
||||
ftfy==6.1.1
|
||||
tensorboard==2.14.0
|
||||
Jinja2==3.1.2
|
||||
Jinja2==3.1.3
|
||||
|
||||
@@ -50,6 +50,7 @@ from diffusers import (
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
MAX_SEQ_LENGTH = 77
|
||||
@@ -58,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -926,6 +927,11 @@ def main(args):
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
@@ -935,9 +941,9 @@ def main(args):
|
||||
" doing mixed precision training, copy of the weights should still be float32."
|
||||
)
|
||||
|
||||
if accelerator.unwrap_model(t2iadapter).dtype != torch.float32:
|
||||
if unwrap_model(t2iadapter).dtype != torch.float32:
|
||||
raise ValueError(
|
||||
f"Controlnet loaded as datatype {accelerator.unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
|
||||
f"Controlnet loaded as datatype {unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
@@ -1198,7 +1204,8 @@ def main(args):
|
||||
encoder_hidden_states=batch["prompt_ids"],
|
||||
added_cond_kwargs=batch["unet_added_conditions"],
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
).sample
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Denoise the latents
|
||||
denoised_latents = model_pred * (-sigmas) + noisy_latents
|
||||
@@ -1279,7 +1286,7 @@ def main(args):
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
t2iadapter = accelerator.unwrap_model(t2iadapter)
|
||||
t2iadapter = unwrap_model(t2iadapter)
|
||||
t2iadapter.save_pretrained(args.output_dir)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -46,6 +46,7 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
@@ -53,7 +54,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -833,6 +834,12 @@ def main():
|
||||
tracker_config.pop("validation_prompts")
|
||||
accelerator.init_trackers(args.tracker_project_name, tracker_config)
|
||||
|
||||
# Function for unwrapping if model was compiled with `torch.compile`.
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -912,7 +919,7 @@ def main():
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if args.prediction_type is not None:
|
||||
@@ -927,7 +934,7 @@ def main():
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# Predict the noise residual and compute loss
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
|
||||
|
||||
if args.snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
@@ -1023,7 +1030,7 @@ def main():
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unwrap_model(unet)
|
||||
if args.use_ema:
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -43,13 +43,14 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.training_utils import cast_training_params, compute_snr
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -466,10 +467,8 @@ def main():
|
||||
# Add adapter and make sure the trainable params are in float32.
|
||||
unet.add_adapter(unet_lora_config)
|
||||
if args.mixed_precision == "fp16":
|
||||
for param in unet.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(unet, dtype=torch.float32)
|
||||
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
@@ -486,6 +485,9 @@ def main():
|
||||
|
||||
lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if args.allow_tf32:
|
||||
@@ -595,6 +597,11 @@ def main():
|
||||
]
|
||||
)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
def preprocess_train(examples):
|
||||
images = [image.convert("RGB") for image in examples[image_column]]
|
||||
examples["pixel_values"] = [train_transforms(image) for image in images]
|
||||
@@ -728,7 +735,7 @@ def main():
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if args.prediction_type is not None:
|
||||
@@ -743,7 +750,7 @@ def main():
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# Predict the noise residual and compute loss
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
|
||||
|
||||
if args.snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
@@ -808,8 +815,10 @@ def main():
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
unwrapped_unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
|
||||
unwrapped_unet = unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(unwrapped_unet)
|
||||
)
|
||||
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=save_path,
|
||||
@@ -834,7 +843,7 @@ def main():
|
||||
# create pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
unet=unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
@@ -875,8 +884,8 @@ def main():
|
||||
if accelerator.is_main_process:
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unwrapped_unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
|
||||
unwrapped_unet = unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
|
||||
@@ -51,13 +51,14 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.training_utils import cast_training_params, compute_snr
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
|
||||
text_input_ids = text_input_ids_list[i]
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(text_encoder.device),
|
||||
output_hidden_states=True,
|
||||
text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
|
||||
)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
prompt_embeds = prompt_embeds[-1][-2]
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
@@ -634,11 +634,13 @@ def main(args):
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
for model in models:
|
||||
for param in model.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
@@ -650,12 +652,16 @@ def main(args):
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -677,11 +683,11 @@ def main(args):
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
unet_ = model
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_ = model
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
@@ -702,6 +708,12 @@ def main(args):
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.gradient_checkpointing_enable()
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if args.allow_tf32:
|
||||
@@ -824,6 +836,9 @@ def main(args):
|
||||
for image in images:
|
||||
original_sizes.append((image.height, image.width))
|
||||
image = train_resize(image)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
# flip
|
||||
image = train_flip(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
||||
@@ -831,10 +846,6 @@ def main(args):
|
||||
else:
|
||||
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
||||
image = crop(image, y1, x1, h, w)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
# flip
|
||||
x1 = image.width - x1
|
||||
image = train_flip(image)
|
||||
crop_top_left = (y1, x1)
|
||||
crop_top_lefts.append(crop_top_left)
|
||||
image = train_transforms(image)
|
||||
@@ -1024,8 +1035,12 @@ def main(args):
|
||||
)
|
||||
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
|
||||
model_pred = unet(
|
||||
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
|
||||
).sample
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds,
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if args.prediction_type is not None:
|
||||
@@ -1118,9 +1133,9 @@ def main(args):
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=unwrap_model(text_encoder_one),
|
||||
text_encoder_2=unwrap_model(text_encoder_two),
|
||||
unet=unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
@@ -1159,15 +1174,15 @@ def main(args):
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet = unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_one = unwrap_model(text_encoder_one)
|
||||
text_encoder_two = unwrap_model(text_encoder_two)
|
||||
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
|
||||
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
|
||||
text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
|
||||
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
|
||||
@@ -44,20 +44,16 @@ from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionXLPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(text_encoder.device),
|
||||
output_hidden_states=True,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
prompt_embeds = prompt_embeds[-1][-2]
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
@@ -842,6 +839,9 @@ def main(args):
|
||||
for image in images:
|
||||
original_sizes.append((image.height, image.width))
|
||||
image = train_resize(image)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
# flip
|
||||
image = train_flip(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
||||
@@ -849,10 +849,6 @@ def main(args):
|
||||
else:
|
||||
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
||||
image = crop(image, y1, x1, h, w)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
# flip
|
||||
x1 = image.width - x1
|
||||
image = train_flip(image)
|
||||
crop_top_left = (y1, x1)
|
||||
crop_top_lefts.append(crop_top_left)
|
||||
image = train_transforms(image)
|
||||
@@ -955,6 +951,12 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
|
||||
|
||||
# Function for unwraping if torch.compile() was used in accelerate.
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -1054,8 +1056,12 @@ def main(args):
|
||||
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
|
||||
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
|
||||
model_pred = unet(
|
||||
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
|
||||
).sample
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds,
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if args.prediction_type is not None:
|
||||
@@ -1206,7 +1212,7 @@ def main(args):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unwrap_model(unet)
|
||||
if args.use_ema:
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
|
||||
|
||||
@@ -60,6 +60,8 @@ Now we can launch the training using:
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
**___Note: Please follow the [README_sdxl.md](./README_sdxl.md) if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export DATA_DIR="./cat"
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
## Textual Inversion fine-tuning example for SDXL
|
||||
|
||||
```
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export DATA_DIR="./cat"
|
||||
|
||||
accelerate launch textual_inversion_sdxl.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_data_dir=$DATA_DIR \
|
||||
--learnable_property="object" \
|
||||
--placeholder_token="<cat-toy>" \
|
||||
--initializer_token="toy" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=768 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--max_train_steps=500 \
|
||||
--learning_rate=5.0e-04 \
|
||||
--scale_lr \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--save_as_full_pipeline \
|
||||
--output_dir="./textual_inversion_cat_sdxl"
|
||||
```
|
||||
|
||||
For now, only training of the first text encoder is supported.
|
||||
@@ -0,0 +1,152 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class TextualInversionSdxl(ExamplesTestsAccelerate):
|
||||
def test_textual_inversion_sdxl(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/textual_inversion/textual_inversion_sdxl.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
|
||||
--train_data_dir docs/source/en/imgs
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
|
||||
|
||||
def test_textual_inversion_sdxl_checkpointing(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/textual_inversion/textual_inversion_sdxl.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
|
||||
--train_data_dir docs/source/en/imgs
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 3
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=1
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-3"},
|
||||
)
|
||||
|
||||
def test_textual_inversion_sdxl_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/textual_inversion/textual_inversion_sdxl.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
|
||||
--train_data_dir docs/source/en/imgs
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-1", "checkpoint-2"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/textual_inversion/textual_inversion_sdxl.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
|
||||
--train_data_dir docs/source/en/imgs
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=1
|
||||
--resume_from_checkpoint=checkpoint-2
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-3"},
|
||||
)
|
||||
@@ -79,7 +79,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -341,7 +341,7 @@ def parse_args():
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
"and Nvidia Ampere GPU or Intel Gen 4 Xeon (and later) ."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import argparse
|
||||
|
||||
import OmegaConf
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
|
||||
|
||||
|
||||
def convert_ldm_original(checkpoint_path, config_path, output_path):
|
||||
config = OmegaConf.load(config_path)
|
||||
config = yaml.safe_load(config_path)
|
||||
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
@@ -25,8 +25,8 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
|
||||
if key.startswith(unet_key):
|
||||
unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
|
||||
|
||||
vqvae_init_args = config.model.params.first_stage_config.params
|
||||
unet_init_args = config.model.params.unet_config.params
|
||||
vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
|
||||
unet_init_args = config["model"]["params"]["unet_config"]["params"]
|
||||
|
||||
vqvae = VQModel(**vqvae_init_args).eval()
|
||||
vqvae.load_state_dict(first_stage_dict)
|
||||
@@ -35,10 +35,10 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
|
||||
unet.load_state_dict(unet_state_dict)
|
||||
|
||||
noise_scheduler = DDIMScheduler(
|
||||
timesteps=config.model.params.timesteps,
|
||||
timesteps=config["model"]["params"]["timesteps"],
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=config.model.params.linear_start,
|
||||
beta_end=config.model.params.linear_end,
|
||||
beta_start=config["model"]["params"]["linear_start"],
|
||||
beta_end=config["model"]["params"]["linear_end"],
|
||||
clip_sample=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
# Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
|
||||
# This means that you can input your diffusers-trained LoRAs and
|
||||
# Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
|
||||
|
||||
# To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
|
||||
# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
|
||||
# and run the script:
|
||||
# python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
|
||||
# now you can use corgy.safetensors in your WebUI of choice!
|
||||
|
||||
# To train your own, here are some diffusers training scripts and utils that you can use and then convert:
|
||||
# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
|
||||
# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
|
||||
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
|
||||
# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
|
||||
# Canonical diffusers training scripts:
|
||||
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
|
||||
|
||||
|
||||
def convert_and_save(input_lora, output_lora=None):
|
||||
if output_lora is None:
|
||||
base_name = os.path.splitext(input_lora)[0]
|
||||
output_lora = f"{base_name}_webui.safetensors"
|
||||
|
||||
diffusers_state_dict = load_file(input_lora)
|
||||
peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
|
||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||
save_file(kohya_state_dict, output_lora)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
|
||||
parser.add_argument(
|
||||
"input_lora",
|
||||
type=str,
|
||||
help="Path to the input LoRA model file in the diffusers format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_lora",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_and_save(args.input_lora, args.output_lora)
|
||||
@@ -2,6 +2,7 @@ import argparse
|
||||
import re
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from transformers import (
|
||||
CLIPProcessor,
|
||||
CLIPTextModel,
|
||||
@@ -28,8 +29,6 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
textenc_conversion_map,
|
||||
textenc_pattern,
|
||||
)
|
||||
from diffusers.utils import is_omegaconf_available
|
||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
|
||||
def convert_open_clip_checkpoint(checkpoint):
|
||||
@@ -370,52 +369,52 @@ def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=Fa
|
||||
|
||||
|
||||
def create_vae_config(original_config, image_size: int):
|
||||
vae_params = original_config.autoencoder.params.ddconfig
|
||||
_ = original_config.autoencoder.params.embed_dim
|
||||
vae_params = original_config["autoencoder"]["params"]["ddconfig"]
|
||||
_ = original_config["autoencoder"]["params"]["embed_dim"]
|
||||
|
||||
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
||||
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
||||
|
||||
config = {
|
||||
"sample_size": image_size,
|
||||
"in_channels": vae_params.in_channels,
|
||||
"out_channels": vae_params.out_ch,
|
||||
"in_channels": vae_params["in_channels"],
|
||||
"out_channels": vae_params["out_ch"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"latent_channels": vae_params.z_channels,
|
||||
"layers_per_block": vae_params.num_res_blocks,
|
||||
"latent_channels": vae_params["z_channels"],
|
||||
"layers_per_block": vae_params["num_res_blocks"],
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_unet_config(original_config, image_size: int, attention_type):
|
||||
unet_params = original_config.model.params
|
||||
vae_params = original_config.autoencoder.params.ddconfig
|
||||
unet_params = original_config["model"]["params"]
|
||||
vae_params = original_config["autoencoder"]["params"]["ddconfig"]
|
||||
|
||||
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
||||
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
|
||||
|
||||
down_block_types = []
|
||||
resolution = 1
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
|
||||
down_block_types.append(block_type)
|
||||
if i != len(block_out_channels) - 1:
|
||||
resolution *= 2
|
||||
|
||||
up_block_types = []
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
||||
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
|
||||
|
||||
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
||||
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
|
||||
use_linear_projection = (
|
||||
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
||||
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
|
||||
)
|
||||
if use_linear_projection:
|
||||
if head_dim is None:
|
||||
@@ -423,11 +422,11 @@ def create_unet_config(original_config, image_size: int, attention_type):
|
||||
|
||||
config = {
|
||||
"sample_size": image_size // vae_scale_factor,
|
||||
"in_channels": unet_params.in_channels,
|
||||
"in_channels": unet_params["in_channels"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"layers_per_block": unet_params.num_res_blocks,
|
||||
"cross_attention_dim": unet_params.context_dim,
|
||||
"layers_per_block": unet_params["num_res_blocks"],
|
||||
"cross_attention_dim": unet_params["context_dim"],
|
||||
"attention_head_dim": head_dim,
|
||||
"use_linear_projection": use_linear_projection,
|
||||
"attention_type": attention_type,
|
||||
@@ -445,11 +444,6 @@ def convert_gligen_to_diffusers(
|
||||
num_in_channels: int = None,
|
||||
device: str = None,
|
||||
):
|
||||
if not is_omegaconf_available():
|
||||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
@@ -461,14 +455,14 @@ def convert_gligen_to_diffusers(
|
||||
else:
|
||||
print("global_step key not found in model")
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
original_config = yaml.safe_load(original_config_file)
|
||||
|
||||
if num_in_channels is not None:
|
||||
original_config["model"]["params"]["in_channels"] = num_in_channels
|
||||
|
||||
num_train_timesteps = original_config.diffusion.params.timesteps
|
||||
beta_start = original_config.diffusion.params.linear_start
|
||||
beta_end = original_config.diffusion.params.linear_end
|
||||
num_train_timesteps = original_config["diffusion"]["params"]["timesteps"]
|
||||
beta_start = original_config["diffusion"]["params"]["linear_start"]
|
||||
beta_end = original_config["diffusion"]["params"]["linear_end"]
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
|
||||
+46
-53
@@ -4,6 +4,7 @@ import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from torch.nn import functional as F
|
||||
from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer
|
||||
|
||||
@@ -11,14 +12,6 @@ from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet
|
||||
from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker
|
||||
|
||||
|
||||
try:
|
||||
from omegaconf import OmegaConf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OmegaConf is required to convert the IF checkpoints. Please install it with `pip install" " OmegaConf`."
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -143,8 +136,8 @@ def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safet
|
||||
|
||||
|
||||
def get_stage_1_unet(unet_config, unet_checkpoint_path):
|
||||
original_unet_config = OmegaConf.load(unet_config)
|
||||
original_unet_config = original_unet_config.params
|
||||
original_unet_config = yaml.safe_load(unet_config)
|
||||
original_unet_config = original_unet_config["params"]
|
||||
|
||||
unet_diffusers_config = create_unet_diffusers_config(original_unet_config)
|
||||
|
||||
@@ -215,11 +208,11 @@ def convert_safety_checker(p_head_path, w_head_path):
|
||||
|
||||
|
||||
def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
|
||||
attention_resolutions = parse_list(original_unet_config.attention_resolutions)
|
||||
attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions]
|
||||
attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
|
||||
attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]
|
||||
|
||||
channel_mult = parse_list(original_unet_config.channel_mult)
|
||||
block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult]
|
||||
channel_mult = parse_list(original_unet_config["channel_mult"])
|
||||
block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
|
||||
|
||||
down_block_types = []
|
||||
resolution = 1
|
||||
@@ -227,7 +220,7 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
|
||||
for i in range(len(block_out_channels)):
|
||||
if resolution in attention_resolutions:
|
||||
block_type = "SimpleCrossAttnDownBlock2D"
|
||||
elif original_unet_config.resblock_updown:
|
||||
elif original_unet_config["resblock_updown"]:
|
||||
block_type = "ResnetDownsampleBlock2D"
|
||||
else:
|
||||
block_type = "DownBlock2D"
|
||||
@@ -241,17 +234,17 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
|
||||
for i in range(len(block_out_channels)):
|
||||
if resolution in attention_resolutions:
|
||||
block_type = "SimpleCrossAttnUpBlock2D"
|
||||
elif original_unet_config.resblock_updown:
|
||||
elif original_unet_config["resblock_updown"]:
|
||||
block_type = "ResnetUpsampleBlock2D"
|
||||
else:
|
||||
block_type = "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
head_dim = original_unet_config.num_head_channels
|
||||
head_dim = original_unet_config["num_head_channels"]
|
||||
|
||||
use_linear_projection = (
|
||||
original_unet_config.use_linear_in_transformer
|
||||
original_unet_config["use_linear_in_transformer"]
|
||||
if "use_linear_in_transformer" in original_unet_config
|
||||
else False
|
||||
)
|
||||
@@ -264,27 +257,27 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
|
||||
|
||||
if class_embed_type is None:
|
||||
if "num_classes" in original_unet_config:
|
||||
if original_unet_config.num_classes == "sequential":
|
||||
if original_unet_config["num_classes"] == "sequential":
|
||||
class_embed_type = "projection"
|
||||
assert "adm_in_channels" in original_unet_config
|
||||
projection_class_embeddings_input_dim = original_unet_config.adm_in_channels
|
||||
projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}"
|
||||
f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
|
||||
)
|
||||
|
||||
config = {
|
||||
"sample_size": original_unet_config.image_size,
|
||||
"in_channels": original_unet_config.in_channels,
|
||||
"sample_size": original_unet_config["image_size"],
|
||||
"in_channels": original_unet_config["in_channels"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"layers_per_block": original_unet_config.num_res_blocks,
|
||||
"cross_attention_dim": original_unet_config.encoder_channels,
|
||||
"layers_per_block": original_unet_config["num_res_blocks"],
|
||||
"cross_attention_dim": original_unet_config["encoder_channels"],
|
||||
"attention_head_dim": head_dim,
|
||||
"use_linear_projection": use_linear_projection,
|
||||
"class_embed_type": class_embed_type,
|
||||
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
||||
"out_channels": original_unet_config.out_channels,
|
||||
"out_channels": original_unet_config["out_channels"],
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"upcast_attention": False, # TODO: guessing
|
||||
"cross_attention_norm": "group_norm",
|
||||
@@ -293,11 +286,11 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
|
||||
"act_fn": "gelu",
|
||||
}
|
||||
|
||||
if original_unet_config.use_scale_shift_norm:
|
||||
if original_unet_config["use_scale_shift_norm"]:
|
||||
config["resnet_time_scale_shift"] = "scale_shift"
|
||||
|
||||
if "encoder_dim" in original_unet_config:
|
||||
config["encoder_hid_dim"] = original_unet_config.encoder_dim
|
||||
config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
|
||||
|
||||
return config
|
||||
|
||||
@@ -725,15 +718,15 @@ def parse_list(value):
|
||||
def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None):
|
||||
orig_path = unet_checkpoint_path
|
||||
|
||||
original_unet_config = OmegaConf.load(os.path.join(orig_path, "config.yml"))
|
||||
original_unet_config = original_unet_config.params
|
||||
original_unet_config = yaml.safe_load(os.path.join(orig_path, "config.yml"))
|
||||
original_unet_config = original_unet_config["params"]
|
||||
|
||||
unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
|
||||
unet_diffusers_config["time_embedding_dim"] = original_unet_config.model_channels * int(
|
||||
original_unet_config.channel_mult.split(",")[-1]
|
||||
unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int(
|
||||
original_unet_config["channel_mult"].split(",")[-1]
|
||||
)
|
||||
if original_unet_config.encoder_dim != original_unet_config.encoder_channels:
|
||||
unet_diffusers_config["encoder_hid_dim"] = original_unet_config.encoder_dim
|
||||
if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
|
||||
unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
|
||||
unet_diffusers_config["class_embed_type"] = "timestep"
|
||||
unet_diffusers_config["addition_embed_type"] = "text"
|
||||
|
||||
@@ -742,16 +735,16 @@ def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_siz
|
||||
unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
|
||||
unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
|
||||
unet_diffusers_config["only_cross_attention"] = (
|
||||
bool(original_unet_config.disable_self_attentions)
|
||||
bool(original_unet_config["disable_self_attentions"])
|
||||
if (
|
||||
"disable_self_attentions" in original_unet_config
|
||||
and isinstance(original_unet_config.disable_self_attentions, int)
|
||||
and isinstance(original_unet_config["disable_self_attentions"], int)
|
||||
)
|
||||
else True
|
||||
)
|
||||
|
||||
if sample_size is None:
|
||||
unet_diffusers_config["sample_size"] = original_unet_config.image_size
|
||||
unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
|
||||
else:
|
||||
# The second upscaler unet's sample size is incorrectly specified
|
||||
# in the config and is instead hardcoded in source
|
||||
@@ -783,11 +776,11 @@ def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_siz
|
||||
|
||||
|
||||
def superres_create_unet_diffusers_config(original_unet_config):
|
||||
attention_resolutions = parse_list(original_unet_config.attention_resolutions)
|
||||
attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions]
|
||||
attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
|
||||
attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]
|
||||
|
||||
channel_mult = parse_list(original_unet_config.channel_mult)
|
||||
block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult]
|
||||
channel_mult = parse_list(original_unet_config["channel_mult"])
|
||||
block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
|
||||
|
||||
down_block_types = []
|
||||
resolution = 1
|
||||
@@ -795,7 +788,7 @@ def superres_create_unet_diffusers_config(original_unet_config):
|
||||
for i in range(len(block_out_channels)):
|
||||
if resolution in attention_resolutions:
|
||||
block_type = "SimpleCrossAttnDownBlock2D"
|
||||
elif original_unet_config.resblock_updown:
|
||||
elif original_unet_config["resblock_updown"]:
|
||||
block_type = "ResnetDownsampleBlock2D"
|
||||
else:
|
||||
block_type = "DownBlock2D"
|
||||
@@ -809,16 +802,16 @@ def superres_create_unet_diffusers_config(original_unet_config):
|
||||
for i in range(len(block_out_channels)):
|
||||
if resolution in attention_resolutions:
|
||||
block_type = "SimpleCrossAttnUpBlock2D"
|
||||
elif original_unet_config.resblock_updown:
|
||||
elif original_unet_config["resblock_updown"]:
|
||||
block_type = "ResnetUpsampleBlock2D"
|
||||
else:
|
||||
block_type = "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
head_dim = original_unet_config.num_head_channels
|
||||
head_dim = original_unet_config["num_head_channels"]
|
||||
use_linear_projection = (
|
||||
original_unet_config.use_linear_in_transformer
|
||||
original_unet_config["use_linear_in_transformer"]
|
||||
if "use_linear_in_transformer" in original_unet_config
|
||||
else False
|
||||
)
|
||||
@@ -831,26 +824,26 @@ def superres_create_unet_diffusers_config(original_unet_config):
|
||||
projection_class_embeddings_input_dim = None
|
||||
|
||||
if "num_classes" in original_unet_config:
|
||||
if original_unet_config.num_classes == "sequential":
|
||||
if original_unet_config["num_classes"] == "sequential":
|
||||
class_embed_type = "projection"
|
||||
assert "adm_in_channels" in original_unet_config
|
||||
projection_class_embeddings_input_dim = original_unet_config.adm_in_channels
|
||||
projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}"
|
||||
f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
|
||||
)
|
||||
|
||||
config = {
|
||||
"in_channels": original_unet_config.in_channels,
|
||||
"in_channels": original_unet_config["in_channels"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"layers_per_block": tuple(original_unet_config.num_res_blocks),
|
||||
"cross_attention_dim": original_unet_config.encoder_channels,
|
||||
"layers_per_block": tuple(original_unet_config["num_res_blocks"]),
|
||||
"cross_attention_dim": original_unet_config["encoder_channels"],
|
||||
"attention_head_dim": head_dim,
|
||||
"use_linear_projection": use_linear_projection,
|
||||
"class_embed_type": class_embed_type,
|
||||
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
||||
"out_channels": original_unet_config.out_channels,
|
||||
"out_channels": original_unet_config["out_channels"],
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"upcast_attention": False, # TODO: guessing
|
||||
"cross_attention_norm": "group_norm",
|
||||
@@ -858,7 +851,7 @@ def superres_create_unet_diffusers_config(original_unet_config):
|
||||
"act_fn": "gelu",
|
||||
}
|
||||
|
||||
if original_unet_config.use_scale_shift_norm:
|
||||
if original_unet_config["use_scale_shift_norm"]:
|
||||
config["resnet_time_scale_shift"] = "scale_shift"
|
||||
|
||||
return config
|
||||
|
||||
@@ -19,6 +19,7 @@ import re
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
AutoTokenizer,
|
||||
@@ -45,7 +46,7 @@ from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import is_omegaconf_available, is_safetensors_available
|
||||
from diffusers.utils import is_safetensors_available
|
||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
|
||||
@@ -212,41 +213,41 @@ def create_unet_diffusers_config(original_config, image_size: int):
|
||||
"""
|
||||
Creates a UNet config for diffusers based on the config of the original AudioLDM2 model.
|
||||
"""
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
unet_params = original_config["model"]["params"]["unet_config"]["params"]
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
|
||||
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
||||
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
|
||||
|
||||
down_block_types = []
|
||||
resolution = 1
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
|
||||
down_block_types.append(block_type)
|
||||
if i != len(block_out_channels) - 1:
|
||||
resolution *= 2
|
||||
|
||||
up_block_types = []
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
||||
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
|
||||
|
||||
cross_attention_dim = list(unet_params.context_dim) if "context_dim" in unet_params else block_out_channels
|
||||
cross_attention_dim = list(unet_params["context_dim"]) if "context_dim" in unet_params else block_out_channels
|
||||
if len(cross_attention_dim) > 1:
|
||||
# require two or more cross-attention layers per-block, each of different dimension
|
||||
cross_attention_dim = [cross_attention_dim for _ in range(len(block_out_channels))]
|
||||
|
||||
config = {
|
||||
"sample_size": image_size // vae_scale_factor,
|
||||
"in_channels": unet_params.in_channels,
|
||||
"out_channels": unet_params.out_channels,
|
||||
"in_channels": unet_params["in_channels"],
|
||||
"out_channels": unet_params["out_channels"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"layers_per_block": unet_params.num_res_blocks,
|
||||
"transformer_layers_per_block": unet_params.transformer_depth,
|
||||
"layers_per_block": unet_params["num_res_blocks"],
|
||||
"transformer_layers_per_block": unet_params["transformer_depth"],
|
||||
"cross_attention_dim": tuple(cross_attention_dim),
|
||||
}
|
||||
|
||||
@@ -259,24 +260,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
|
||||
Creates a VAE config for diffusers based on the config of the original AudioLDM2 model. Compared to the original
|
||||
Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
|
||||
"""
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
_ = original_config.model.params.first_stage_config.params.embed_dim
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
|
||||
|
||||
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
||||
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
||||
|
||||
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215
|
||||
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
|
||||
|
||||
config = {
|
||||
"sample_size": image_size,
|
||||
"in_channels": vae_params.in_channels,
|
||||
"out_channels": vae_params.out_ch,
|
||||
"in_channels": vae_params["in_channels"],
|
||||
"out_channels": vae_params["out_ch"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"latent_channels": vae_params.z_channels,
|
||||
"layers_per_block": vae_params.num_res_blocks,
|
||||
"latent_channels": vae_params["z_channels"],
|
||||
"layers_per_block": vae_params["num_res_blocks"],
|
||||
"scaling_factor": float(scaling_factor),
|
||||
}
|
||||
return config
|
||||
@@ -285,9 +286,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
|
||||
def create_diffusers_schedular(original_config):
|
||||
schedular = DDIMScheduler(
|
||||
num_train_timesteps=original_config.model.params.timesteps,
|
||||
beta_start=original_config.model.params.linear_start,
|
||||
beta_end=original_config.model.params.linear_end,
|
||||
num_train_timesteps=original_config["model"]["params"]["timesteps"],
|
||||
beta_start=original_config["model"]["params"]["linear_start"],
|
||||
beta_end=original_config["model"]["params"]["linear_end"],
|
||||
beta_schedule="scaled_linear",
|
||||
)
|
||||
return schedular
|
||||
@@ -692,17 +693,17 @@ def create_transformers_vocoder_config(original_config):
|
||||
"""
|
||||
Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
|
||||
"""
|
||||
vocoder_params = original_config.model.params.vocoder_config.params
|
||||
vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
|
||||
|
||||
config = {
|
||||
"model_in_dim": vocoder_params.num_mels,
|
||||
"sampling_rate": vocoder_params.sampling_rate,
|
||||
"upsample_initial_channel": vocoder_params.upsample_initial_channel,
|
||||
"upsample_rates": list(vocoder_params.upsample_rates),
|
||||
"upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes),
|
||||
"resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes),
|
||||
"model_in_dim": vocoder_params["num_mels"],
|
||||
"sampling_rate": vocoder_params["sampling_rate"],
|
||||
"upsample_initial_channel": vocoder_params["upsample_initial_channel"],
|
||||
"upsample_rates": list(vocoder_params["upsample_rates"]),
|
||||
"upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
|
||||
"resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
|
||||
"resblock_dilation_sizes": [
|
||||
list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
|
||||
list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
|
||||
],
|
||||
"normalize_before": False,
|
||||
}
|
||||
@@ -876,11 +877,6 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
|
||||
return: An AudioLDM2Pipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
||||
"""
|
||||
|
||||
if not is_omegaconf_available():
|
||||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if from_safetensors:
|
||||
if not is_safetensors_available():
|
||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
||||
@@ -903,9 +899,8 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
|
||||
|
||||
if original_config_file is None:
|
||||
original_config = DEFAULT_CONFIG
|
||||
original_config = OmegaConf.create(original_config)
|
||||
else:
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
original_config = yaml.safe_load(original_config_file)
|
||||
|
||||
if image_size is not None:
|
||||
original_config["model"]["params"]["unet_config"]["params"]["image_size"] = image_size
|
||||
@@ -926,9 +921,9 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
|
||||
if prediction_type is None:
|
||||
prediction_type = "epsilon"
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
num_train_timesteps = original_config["model"]["params"]["timesteps"]
|
||||
beta_start = original_config["model"]["params"]["linear_start"]
|
||||
beta_end = original_config["model"]["params"]["linear_end"]
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
@@ -1026,9 +1021,9 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
|
||||
# Convert the GPT2 encoder model: AudioLDM2 uses the same configuration as the original GPT2 base model
|
||||
gpt2_config = GPT2Config.from_pretrained("gpt2")
|
||||
gpt2_model = GPT2Model(gpt2_config)
|
||||
gpt2_model.config.max_new_tokens = (
|
||||
original_config.model.params.cond_stage_config.crossattn_audiomae_generated.params.sequence_gen_length
|
||||
)
|
||||
gpt2_model.config.max_new_tokens = original_config["model"]["params"]["cond_stage_config"][
|
||||
"crossattn_audiomae_generated"
|
||||
]["params"]["sequence_gen_length"]
|
||||
|
||||
converted_gpt2_checkpoint = extract_sub_model(checkpoint, key_prefix="cond_stage_models.0.model.")
|
||||
gpt2_model.load_state_dict(converted_gpt2_checkpoint)
|
||||
|
||||
@@ -18,6 +18,7 @@ import argparse
|
||||
import re
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
ClapTextConfig,
|
||||
@@ -38,8 +39,6 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import is_omegaconf_available
|
||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
|
||||
@@ -215,45 +214,45 @@ def create_unet_diffusers_config(original_config, image_size: int):
|
||||
"""
|
||||
Creates a UNet config for diffusers based on the config of the original AudioLDM model.
|
||||
"""
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
unet_params = original_config["model"]["params"]["unet_config"]["params"]
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
|
||||
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
||||
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
|
||||
|
||||
down_block_types = []
|
||||
resolution = 1
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
|
||||
down_block_types.append(block_type)
|
||||
if i != len(block_out_channels) - 1:
|
||||
resolution *= 2
|
||||
|
||||
up_block_types = []
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
||||
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
|
||||
|
||||
cross_attention_dim = (
|
||||
unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels
|
||||
unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels
|
||||
)
|
||||
|
||||
class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None
|
||||
projection_class_embeddings_input_dim = (
|
||||
unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None
|
||||
unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None
|
||||
)
|
||||
class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None
|
||||
class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None
|
||||
|
||||
config = {
|
||||
"sample_size": image_size // vae_scale_factor,
|
||||
"in_channels": unet_params.in_channels,
|
||||
"out_channels": unet_params.out_channels,
|
||||
"in_channels": unet_params["in_channels"],
|
||||
"out_channels": unet_params["out_channels"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"layers_per_block": unet_params.num_res_blocks,
|
||||
"layers_per_block": unet_params["num_res_blocks"],
|
||||
"cross_attention_dim": cross_attention_dim,
|
||||
"class_embed_type": class_embed_type,
|
||||
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
||||
@@ -269,24 +268,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
|
||||
Creates a VAE config for diffusers based on the config of the original AudioLDM model. Compared to the original
|
||||
Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
|
||||
"""
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
_ = original_config.model.params.first_stage_config.params.embed_dim
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
|
||||
|
||||
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
||||
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
||||
|
||||
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215
|
||||
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
|
||||
|
||||
config = {
|
||||
"sample_size": image_size,
|
||||
"in_channels": vae_params.in_channels,
|
||||
"out_channels": vae_params.out_ch,
|
||||
"in_channels": vae_params["in_channels"],
|
||||
"out_channels": vae_params["out_ch"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"latent_channels": vae_params.z_channels,
|
||||
"layers_per_block": vae_params.num_res_blocks,
|
||||
"latent_channels": vae_params["z_channels"],
|
||||
"layers_per_block": vae_params["num_res_blocks"],
|
||||
"scaling_factor": float(scaling_factor),
|
||||
}
|
||||
return config
|
||||
@@ -295,9 +294,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
|
||||
def create_diffusers_schedular(original_config):
|
||||
schedular = DDIMScheduler(
|
||||
num_train_timesteps=original_config.model.params.timesteps,
|
||||
beta_start=original_config.model.params.linear_start,
|
||||
beta_end=original_config.model.params.linear_end,
|
||||
num_train_timesteps=original_config["model"]["params"]["timesteps"],
|
||||
beta_start=original_config["model"]["params"]["linear_start"],
|
||||
beta_end=original_config["model"]["params"]["linear_end"],
|
||||
beta_schedule="scaled_linear",
|
||||
)
|
||||
return schedular
|
||||
@@ -668,17 +667,17 @@ def create_transformers_vocoder_config(original_config):
|
||||
"""
|
||||
Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
|
||||
"""
|
||||
vocoder_params = original_config.model.params.vocoder_config.params
|
||||
vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
|
||||
|
||||
config = {
|
||||
"model_in_dim": vocoder_params.num_mels,
|
||||
"sampling_rate": vocoder_params.sampling_rate,
|
||||
"upsample_initial_channel": vocoder_params.upsample_initial_channel,
|
||||
"upsample_rates": list(vocoder_params.upsample_rates),
|
||||
"upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes),
|
||||
"resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes),
|
||||
"model_in_dim": vocoder_params["num_mels"],
|
||||
"sampling_rate": vocoder_params["sampling_rate"],
|
||||
"upsample_initial_channel": vocoder_params["upsample_initial_channel"],
|
||||
"upsample_rates": list(vocoder_params["upsample_rates"]),
|
||||
"upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
|
||||
"resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
|
||||
"resblock_dilation_sizes": [
|
||||
list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
|
||||
list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
|
||||
],
|
||||
"normalize_before": False,
|
||||
}
|
||||
@@ -818,11 +817,6 @@ def load_pipeline_from_original_audioldm_ckpt(
|
||||
return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
||||
"""
|
||||
|
||||
if not is_omegaconf_available():
|
||||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if from_safetensors:
|
||||
from safetensors import safe_open
|
||||
|
||||
@@ -842,9 +836,8 @@ def load_pipeline_from_original_audioldm_ckpt(
|
||||
|
||||
if original_config_file is None:
|
||||
original_config = DEFAULT_CONFIG
|
||||
original_config = OmegaConf.create(original_config)
|
||||
else:
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
original_config = yaml.safe_load(original_config_file)
|
||||
|
||||
if num_in_channels is not None:
|
||||
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||
@@ -868,9 +861,9 @@ def load_pipeline_from_original_audioldm_ckpt(
|
||||
if image_size is None:
|
||||
image_size = 512
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
num_train_timesteps = original_config["model"]["params"]["timesteps"]
|
||||
beta_start = original_config["model"]["params"]["linear_start"]
|
||||
beta_end = original_config["model"]["params"]["linear_end"]
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
|
||||
@@ -18,6 +18,7 @@ import argparse
|
||||
import re
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
AutoTokenizer,
|
||||
@@ -39,8 +40,6 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import is_omegaconf_available
|
||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
|
||||
@@ -212,45 +211,45 @@ def create_unet_diffusers_config(original_config, image_size: int):
|
||||
"""
|
||||
Creates a UNet config for diffusers based on the config of the original MusicLDM model.
|
||||
"""
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
unet_params = original_config["model"]["params"]["unet_config"]["params"]
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
|
||||
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
||||
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
|
||||
|
||||
down_block_types = []
|
||||
resolution = 1
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
|
||||
down_block_types.append(block_type)
|
||||
if i != len(block_out_channels) - 1:
|
||||
resolution *= 2
|
||||
|
||||
up_block_types = []
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
||||
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
|
||||
|
||||
cross_attention_dim = (
|
||||
unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels
|
||||
unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels
|
||||
)
|
||||
|
||||
class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None
|
||||
projection_class_embeddings_input_dim = (
|
||||
unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None
|
||||
unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None
|
||||
)
|
||||
class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None
|
||||
class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None
|
||||
|
||||
config = {
|
||||
"sample_size": image_size // vae_scale_factor,
|
||||
"in_channels": unet_params.in_channels,
|
||||
"out_channels": unet_params.out_channels,
|
||||
"in_channels": unet_params["in_channels"],
|
||||
"out_channels": unet_params["out_channels"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"layers_per_block": unet_params.num_res_blocks,
|
||||
"layers_per_block": unet_params["num_res_blocks"],
|
||||
"cross_attention_dim": cross_attention_dim,
|
||||
"class_embed_type": class_embed_type,
|
||||
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
||||
@@ -266,24 +265,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
|
||||
Creates a VAE config for diffusers based on the config of the original MusicLDM model. Compared to the original
|
||||
Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
|
||||
"""
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
_ = original_config.model.params.first_stage_config.params.embed_dim
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
|
||||
|
||||
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
||||
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
||||
|
||||
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215
|
||||
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
|
||||
|
||||
config = {
|
||||
"sample_size": image_size,
|
||||
"in_channels": vae_params.in_channels,
|
||||
"out_channels": vae_params.out_ch,
|
||||
"in_channels": vae_params["in_channels"],
|
||||
"out_channels": vae_params["out_ch"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"latent_channels": vae_params.z_channels,
|
||||
"layers_per_block": vae_params.num_res_blocks,
|
||||
"latent_channels": vae_params["z_channels"],
|
||||
"layers_per_block": vae_params["num_res_blocks"],
|
||||
"scaling_factor": float(scaling_factor),
|
||||
}
|
||||
return config
|
||||
@@ -292,9 +291,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
|
||||
def create_diffusers_schedular(original_config):
|
||||
schedular = DDIMScheduler(
|
||||
num_train_timesteps=original_config.model.params.timesteps,
|
||||
beta_start=original_config.model.params.linear_start,
|
||||
beta_end=original_config.model.params.linear_end,
|
||||
num_train_timesteps=original_config["model"]["params"]["timesteps"],
|
||||
beta_start=original_config["model"]["params"]["linear_start"],
|
||||
beta_end=original_config["model"]["params"]["linear_end"],
|
||||
beta_schedule="scaled_linear",
|
||||
)
|
||||
return schedular
|
||||
@@ -674,17 +673,17 @@ def create_transformers_vocoder_config(original_config):
|
||||
"""
|
||||
Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
|
||||
"""
|
||||
vocoder_params = original_config.model.params.vocoder_config.params
|
||||
vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
|
||||
|
||||
config = {
|
||||
"model_in_dim": vocoder_params.num_mels,
|
||||
"sampling_rate": vocoder_params.sampling_rate,
|
||||
"upsample_initial_channel": vocoder_params.upsample_initial_channel,
|
||||
"upsample_rates": list(vocoder_params.upsample_rates),
|
||||
"upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes),
|
||||
"resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes),
|
||||
"model_in_dim": vocoder_params["num_mels"],
|
||||
"sampling_rate": vocoder_params["sampling_rate"],
|
||||
"upsample_initial_channel": vocoder_params["upsample_initial_channel"],
|
||||
"upsample_rates": list(vocoder_params["upsample_rates"]),
|
||||
"upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
|
||||
"resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
|
||||
"resblock_dilation_sizes": [
|
||||
list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
|
||||
list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
|
||||
],
|
||||
"normalize_before": False,
|
||||
}
|
||||
@@ -823,12 +822,6 @@ def load_pipeline_from_original_MusicLDM_ckpt(
|
||||
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
|
||||
return: An MusicLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
||||
"""
|
||||
|
||||
if not is_omegaconf_available():
|
||||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if from_safetensors:
|
||||
from safetensors import safe_open
|
||||
|
||||
@@ -848,9 +841,8 @@ def load_pipeline_from_original_MusicLDM_ckpt(
|
||||
|
||||
if original_config_file is None:
|
||||
original_config = DEFAULT_CONFIG
|
||||
original_config = OmegaConf.create(original_config)
|
||||
else:
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
original_config = yaml.safe_load(original_config_file)
|
||||
|
||||
if num_in_channels is not None:
|
||||
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||
@@ -874,9 +866,9 @@ def load_pipeline_from_original_MusicLDM_ckpt(
|
||||
if image_size is None:
|
||||
image_size = 512
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
num_train_timesteps = original_config["model"]["params"]["timesteps"]
|
||||
beta_start = original_config["model"]["params"]["linear_start"]
|
||||
beta_end = original_config["model"]["params"]["linear_end"]
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
|
||||
@@ -3,7 +3,7 @@ import io
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
import yaml
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
@@ -126,7 +126,7 @@ def vae_pt_to_vae_diffuser(
|
||||
)
|
||||
io_obj = io.BytesIO(r.content)
|
||||
|
||||
original_config = OmegaConf.load(io_obj)
|
||||
original_config = yaml.safe_load(io_obj)
|
||||
image_size = 512
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if checkpoint_path.endswith("safetensors"):
|
||||
|
||||
@@ -45,51 +45,45 @@ from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionSchedu
|
||||
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
|
||||
|
||||
|
||||
try:
|
||||
from omegaconf import OmegaConf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
|
||||
" OmegaConf`."
|
||||
)
|
||||
|
||||
# vqvae model
|
||||
|
||||
PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]
|
||||
|
||||
|
||||
def vqvae_model_from_original_config(original_config):
|
||||
assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_config["target"] in PORTED_VQVAES
|
||||
), f"{original_config['target']} has not yet been ported to diffusers."
|
||||
|
||||
original_config = original_config.params
|
||||
original_config = original_config["params"]
|
||||
|
||||
original_encoder_config = original_config.encoder_config.params
|
||||
original_decoder_config = original_config.decoder_config.params
|
||||
original_encoder_config = original_config["encoder_config"]["params"]
|
||||
original_decoder_config = original_config["decoder_config"]["params"]
|
||||
|
||||
in_channels = original_encoder_config.in_channels
|
||||
out_channels = original_decoder_config.out_ch
|
||||
in_channels = original_encoder_config["in_channels"]
|
||||
out_channels = original_decoder_config["out_ch"]
|
||||
|
||||
down_block_types = get_down_block_types(original_encoder_config)
|
||||
up_block_types = get_up_block_types(original_decoder_config)
|
||||
|
||||
assert original_encoder_config.ch == original_decoder_config.ch
|
||||
assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
|
||||
assert original_encoder_config["ch"] == original_decoder_config["ch"]
|
||||
assert original_encoder_config["ch_mult"] == original_decoder_config["ch_mult"]
|
||||
block_out_channels = tuple(
|
||||
[original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
|
||||
[original_encoder_config["ch"] * a_ch_mult for a_ch_mult in original_encoder_config["ch_mult"]]
|
||||
)
|
||||
|
||||
assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
|
||||
layers_per_block = original_encoder_config.num_res_blocks
|
||||
assert original_encoder_config["num_res_blocks"] == original_decoder_config["num_res_blocks"]
|
||||
layers_per_block = original_encoder_config["num_res_blocks"]
|
||||
|
||||
assert original_encoder_config.z_channels == original_decoder_config.z_channels
|
||||
latent_channels = original_encoder_config.z_channels
|
||||
assert original_encoder_config["z_channels"] == original_decoder_config["z_channels"]
|
||||
latent_channels = original_encoder_config["z_channels"]
|
||||
|
||||
num_vq_embeddings = original_config.n_embed
|
||||
num_vq_embeddings = original_config["n_embed"]
|
||||
|
||||
# Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion
|
||||
norm_num_groups = 32
|
||||
|
||||
e_dim = original_config.embed_dim
|
||||
e_dim = original_config["embed_dim"]
|
||||
|
||||
model = VQModel(
|
||||
in_channels=in_channels,
|
||||
@@ -108,9 +102,9 @@ def vqvae_model_from_original_config(original_config):
|
||||
|
||||
|
||||
def get_down_block_types(original_encoder_config):
|
||||
attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
|
||||
num_resolutions = len(original_encoder_config.ch_mult)
|
||||
resolution = coerce_resolution(original_encoder_config.resolution)
|
||||
attn_resolutions = coerce_attn_resolutions(original_encoder_config["attn_resolutions"])
|
||||
num_resolutions = len(original_encoder_config["ch_mult"])
|
||||
resolution = coerce_resolution(original_encoder_config["resolution"])
|
||||
|
||||
curr_res = resolution
|
||||
down_block_types = []
|
||||
@@ -129,9 +123,9 @@ def get_down_block_types(original_encoder_config):
|
||||
|
||||
|
||||
def get_up_block_types(original_decoder_config):
|
||||
attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
|
||||
num_resolutions = len(original_decoder_config.ch_mult)
|
||||
resolution = coerce_resolution(original_decoder_config.resolution)
|
||||
attn_resolutions = coerce_attn_resolutions(original_decoder_config["attn_resolutions"])
|
||||
num_resolutions = len(original_decoder_config["ch_mult"])
|
||||
resolution = coerce_resolution(original_decoder_config["resolution"])
|
||||
|
||||
curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
|
||||
up_block_types = []
|
||||
@@ -150,7 +144,7 @@ def get_up_block_types(original_decoder_config):
|
||||
|
||||
|
||||
def coerce_attn_resolutions(attn_resolutions):
|
||||
attn_resolutions = OmegaConf.to_object(attn_resolutions)
|
||||
attn_resolutions = list(attn_resolutions)
|
||||
attn_resolutions_ = []
|
||||
for ar in attn_resolutions:
|
||||
if isinstance(ar, (list, tuple)):
|
||||
@@ -161,7 +155,6 @@ def coerce_attn_resolutions(attn_resolutions):
|
||||
|
||||
|
||||
def coerce_resolution(resolution):
|
||||
resolution = OmegaConf.to_object(resolution)
|
||||
if isinstance(resolution, int):
|
||||
resolution = [resolution, resolution] # H, W
|
||||
elif isinstance(resolution, (tuple, list)):
|
||||
@@ -472,18 +465,18 @@ def transformer_model_from_original_config(
|
||||
original_diffusion_config, original_transformer_config, original_content_embedding_config
|
||||
):
|
||||
assert (
|
||||
original_diffusion_config.target in PORTED_DIFFUSIONS
|
||||
), f"{original_diffusion_config.target} has not yet been ported to diffusers."
|
||||
original_diffusion_config["target"] in PORTED_DIFFUSIONS
|
||||
), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_transformer_config.target in PORTED_TRANSFORMERS
|
||||
), f"{original_transformer_config.target} has not yet been ported to diffusers."
|
||||
original_transformer_config["target"] in PORTED_TRANSFORMERS
|
||||
), f"{original_transformer_config['target']} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
|
||||
), f"{original_content_embedding_config.target} has not yet been ported to diffusers."
|
||||
original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
|
||||
), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
|
||||
|
||||
original_diffusion_config = original_diffusion_config.params
|
||||
original_transformer_config = original_transformer_config.params
|
||||
original_content_embedding_config = original_content_embedding_config.params
|
||||
original_diffusion_config = original_diffusion_config["params"]
|
||||
original_transformer_config = original_transformer_config["params"]
|
||||
original_content_embedding_config = original_content_embedding_config["params"]
|
||||
|
||||
inner_dim = original_transformer_config["n_embd"]
|
||||
|
||||
@@ -689,13 +682,11 @@ def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_fee
|
||||
|
||||
def read_config_file(filename):
|
||||
# The yaml file contains annotations that certain values should
|
||||
# loaded as tuples. By default, OmegaConf will panic when reading
|
||||
# these. Instead, we can manually read the yaml with the FullLoader and then
|
||||
# construct the OmegaConf object.
|
||||
# loaded as tuples.
|
||||
with open(filename) as f:
|
||||
original_config = yaml.load(f, FullLoader)
|
||||
|
||||
return OmegaConf.create(original_config)
|
||||
return original_config
|
||||
|
||||
|
||||
# We take separate arguments for the vqvae because the ITHQ vqvae config file
|
||||
@@ -792,9 +783,9 @@ if __name__ == "__main__":
|
||||
|
||||
original_config = read_config_file(args.original_config_file).model
|
||||
|
||||
diffusion_config = original_config.params.diffusion_config
|
||||
transformer_config = original_config.params.diffusion_config.params.transformer_config
|
||||
content_embedding_config = original_config.params.diffusion_config.params.content_emb_config
|
||||
diffusion_config = original_config["params"]["diffusion_config"]
|
||||
transformer_config = original_config["params"]["diffusion_config"]["params"]["transformer_config"]
|
||||
content_embedding_config = original_config["params"]["diffusion_config"]["params"]["content_emb_config"]
|
||||
|
||||
pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)
|
||||
|
||||
@@ -831,7 +822,7 @@ if __name__ == "__main__":
|
||||
# The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
|
||||
# model, so we pull them off the checkpoint before the checkpoint is deleted.
|
||||
|
||||
learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf
|
||||
learnable_classifier_free_sampling_embeddings = diffusion_config["params"].learnable_cf
|
||||
|
||||
if learnable_classifier_free_sampling_embeddings:
|
||||
learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
|
||||
|
||||
@@ -14,6 +14,7 @@ $ python convert_zero123_to_diffusers.py \
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from pipeline_zero1to3 import CCProjection, Zero1to3StableDiffusionPipeline
|
||||
@@ -38,51 +39,54 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
if controlnet:
|
||||
unet_params = original_config.model.params.control_stage_config.params
|
||||
unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
|
||||
else:
|
||||
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
if (
|
||||
"unet_config" in original_config["model"]["params"]
|
||||
and original_config["model"]["params"]["unet_config"] is not None
|
||||
):
|
||||
unet_params = original_config["model"]["params"]["unet_config"]["params"]
|
||||
else:
|
||||
unet_params = original_config.model.params.network_config.params
|
||||
unet_params = original_config["model"]["params"]["network_config"]["params"]
|
||||
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
|
||||
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
||||
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
|
||||
|
||||
down_block_types = []
|
||||
resolution = 1
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
|
||||
down_block_types.append(block_type)
|
||||
if i != len(block_out_channels) - 1:
|
||||
resolution *= 2
|
||||
|
||||
up_block_types = []
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
if unet_params.transformer_depth is not None:
|
||||
if unet_params["transformer_depth"] is not None:
|
||||
transformer_layers_per_block = (
|
||||
unet_params.transformer_depth
|
||||
if isinstance(unet_params.transformer_depth, int)
|
||||
else list(unet_params.transformer_depth)
|
||||
unet_params["transformer_depth"]
|
||||
if isinstance(unet_params["transformer_depth"], int)
|
||||
else list(unet_params["transformer_depth"])
|
||||
)
|
||||
else:
|
||||
transformer_layers_per_block = 1
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
||||
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
|
||||
|
||||
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
||||
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
|
||||
use_linear_projection = (
|
||||
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
||||
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
|
||||
)
|
||||
if use_linear_projection:
|
||||
# stable diffusion 2-base-512 and 2-768
|
||||
if head_dim is None:
|
||||
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
|
||||
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
|
||||
head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
|
||||
head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
|
||||
|
||||
class_embed_type = None
|
||||
addition_embed_type = None
|
||||
@@ -90,13 +94,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
projection_class_embeddings_input_dim = None
|
||||
context_dim = None
|
||||
|
||||
if unet_params.context_dim is not None:
|
||||
if unet_params["context_dim"] is not None:
|
||||
context_dim = (
|
||||
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
|
||||
unet_params["context_dim"]
|
||||
if isinstance(unet_params["context_dim"], int)
|
||||
else unet_params["context_dim"][0]
|
||||
)
|
||||
|
||||
if "num_classes" in unet_params:
|
||||
if unet_params.num_classes == "sequential":
|
||||
if unet_params["num_classes"] == "sequential":
|
||||
if context_dim in [2048, 1280]:
|
||||
# SDXL
|
||||
addition_embed_type = "text_time"
|
||||
@@ -104,16 +110,16 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
else:
|
||||
class_embed_type = "projection"
|
||||
assert "adm_in_channels" in unet_params
|
||||
projection_class_embeddings_input_dim = unet_params.adm_in_channels
|
||||
projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
|
||||
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params["num_classes"]}")
|
||||
|
||||
config = {
|
||||
"sample_size": image_size // vae_scale_factor,
|
||||
"in_channels": unet_params.in_channels,
|
||||
"in_channels": unet_params["in_channels"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"layers_per_block": unet_params.num_res_blocks,
|
||||
"layers_per_block": unet_params["num_res_blocks"],
|
||||
"cross_attention_dim": context_dim,
|
||||
"attention_head_dim": head_dim,
|
||||
"use_linear_projection": use_linear_projection,
|
||||
@@ -125,9 +131,9 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
}
|
||||
|
||||
if controlnet:
|
||||
config["conditioning_channels"] = unet_params.hint_channels
|
||||
config["conditioning_channels"] = unet_params["hint_channels"]
|
||||
else:
|
||||
config["out_channels"] = unet_params.out_channels
|
||||
config["out_channels"] = unet_params["out_channels"]
|
||||
config["up_block_types"] = tuple(up_block_types)
|
||||
|
||||
return config
|
||||
@@ -487,22 +493,22 @@ def create_vae_diffusers_config(original_config, image_size: int):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
_ = original_config.model.params.first_stage_config.params.embed_dim
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
|
||||
|
||||
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
||||
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
||||
|
||||
config = {
|
||||
"sample_size": image_size,
|
||||
"in_channels": vae_params.in_channels,
|
||||
"out_channels": vae_params.out_ch,
|
||||
"in_channels": vae_params["in_channels"],
|
||||
"out_channels": vae_params["out_ch"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"latent_channels": vae_params.z_channels,
|
||||
"layers_per_block": vae_params.num_res_blocks,
|
||||
"latent_channels": vae_params["z_channels"],
|
||||
"layers_per_block": vae_params["num_res_blocks"],
|
||||
}
|
||||
return config
|
||||
|
||||
@@ -679,18 +685,16 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex
|
||||
del ckpt
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
original_config = yaml.safe_load(original_config_file)
|
||||
original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
|
||||
num_in_channels = 8
|
||||
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||
prediction_type = "epsilon"
|
||||
image_size = 256
|
||||
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
|
||||
num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000
|
||||
|
||||
beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
|
||||
beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
|
||||
beta_start = getattr(original_config["model"]["params"], "linear_start", None) or 0.02
|
||||
beta_end = getattr(original_config["model"]["params"], "linear_end", None) or 0.085
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
beta_schedule="scaled_linear",
|
||||
@@ -721,10 +725,10 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex
|
||||
|
||||
if (
|
||||
"model" in original_config
|
||||
and "params" in original_config.model
|
||||
and "scale_factor" in original_config.model.params
|
||||
and "params" in original_config["model"]
|
||||
and "scale_factor" in original_config["model"]["params"]
|
||||
):
|
||||
vae_scaling_factor = original_config.model.params.scale_factor
|
||||
vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
|
||||
else:
|
||||
vae_scaling_factor = 0.18215 # default SD scaling factor
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ _deps = [
|
||||
"filelock",
|
||||
"flax>=0.4.1",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub>=0.19.4",
|
||||
"huggingface-hub>=0.20.2",
|
||||
"requests-mock==1.10.0",
|
||||
"importlib_metadata",
|
||||
"invisible-watermark>=0.2.0",
|
||||
@@ -110,7 +110,6 @@ _deps = [
|
||||
"note_seq",
|
||||
"librosa",
|
||||
"numpy",
|
||||
"omegaconf",
|
||||
"parameterized",
|
||||
"peft>=0.6.0",
|
||||
"protobuf>=3.20.3,<4",
|
||||
@@ -213,7 +212,6 @@ extras["test"] = deps_list(
|
||||
"invisible-watermark",
|
||||
"k-diffusion",
|
||||
"librosa",
|
||||
"omegaconf",
|
||||
"parameterized",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
@@ -251,7 +249,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.25.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.26.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.25.0.dev0"
|
||||
__version__ = "0.26.0.dev0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -316,7 +316,7 @@ except OptionalDependencyNotAvailable:
|
||||
]
|
||||
|
||||
else:
|
||||
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"])
|
||||
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
@@ -668,7 +668,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import StableDiffusionKDiffusionPipeline
|
||||
from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
|
||||
@@ -9,7 +9,7 @@ deps = {
|
||||
"filelock": "filelock",
|
||||
"flax": "flax>=0.4.1",
|
||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.19.4",
|
||||
"huggingface-hub": "huggingface-hub>=0.20.2",
|
||||
"requests-mock": "requests-mock==1.10.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"invisible-watermark": "invisible-watermark>=0.2.0",
|
||||
@@ -22,7 +22,6 @@ deps = {
|
||||
"note_seq": "note_seq",
|
||||
"librosa": "librosa",
|
||||
"numpy": "numpy",
|
||||
"omegaconf": "omegaconf",
|
||||
"parameterized": "parameterized",
|
||||
"peft": "peft>=0.6.0",
|
||||
"protobuf": "protobuf>=3.20.3,<4",
|
||||
|
||||
@@ -49,10 +49,9 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
|
||||
env,
|
||||
):
|
||||
super().__init__()
|
||||
self.value_function = value_function
|
||||
self.unet = unet
|
||||
self.scheduler = scheduler
|
||||
self.env = env
|
||||
|
||||
self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env)
|
||||
|
||||
self.data = env.get_dataset()
|
||||
self.means = {}
|
||||
for key in self.data.keys():
|
||||
|
||||
@@ -33,14 +33,7 @@ PipelineImageInput = Union[
|
||||
List[torch.FloatTensor],
|
||||
]
|
||||
|
||||
PipelineDepthInput = Union[
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
torch.FloatTensor,
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
List[torch.FloatTensor],
|
||||
]
|
||||
PipelineDepthInput = PipelineImageInput
|
||||
|
||||
|
||||
class VaeImageProcessor(ConfigMixin):
|
||||
@@ -169,7 +162,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
@staticmethod
|
||||
def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
|
||||
"""
|
||||
Blurs an image.
|
||||
Applies Gaussian blur to an image.
|
||||
"""
|
||||
image = image.filter(ImageFilter.GaussianBlur(blur_factor))
|
||||
|
||||
@@ -402,6 +395,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
"""
|
||||
image[image < 0.5] = 0
|
||||
image[image >= 0.5] = 1
|
||||
|
||||
return image
|
||||
|
||||
def get_default_height_width(
|
||||
@@ -634,7 +628,9 @@ class VaeImageProcessor(ConfigMixin):
|
||||
init_image_masked = init_image_masked.convert("RGBA")
|
||||
|
||||
if crop_coords is not None:
|
||||
x, y, w, h = crop_coords
|
||||
x, y, x2, y2 = crop_coords
|
||||
w = x2 - x
|
||||
h = y2 - y
|
||||
base_image = PIL.Image.new("RGBA", (width, height))
|
||||
image = self.resize(image, height=h, width=w, resize_mode="crop")
|
||||
base_image.paste(image, (x, y))
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
|
||||
from ..utils.import_utils import is_torch_available, is_transformers_available
|
||||
from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
@@ -64,6 +64,8 @@ if is_torch_available():
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
|
||||
|
||||
_import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
@@ -76,6 +78,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
|
||||
from .single_file import FromSingleFileMixin
|
||||
from .textual_inversion import TextualInversionLoaderMixin
|
||||
|
||||
from .peft import PeftAdapterMixin
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
@@ -132,21 +132,23 @@ class IPAdapterMixin:
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
# load CLIP image encoer here if it has not been registered to the pipeline yet
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
subfolder=os.path.join(subfolder, "image_encoder"),
|
||||
subfolder=Path(subfolder, "image_encoder").as_posix(),
|
||||
).to(self.device, dtype=self.dtype)
|
||||
self.image_encoder = image_encoder
|
||||
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
|
||||
else:
|
||||
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
self.feature_extractor = CLIPImageProcessor()
|
||||
self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"])
|
||||
|
||||
# load ip-adapter into unet
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
@@ -157,3 +159,32 @@ class IPAdapterMixin:
|
||||
for attn_processor in unet.attn_processors.values():
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.unet.encoder_hid_proj = None
|
||||
self.config.encoder_hid_dim_type = None
|
||||
|
||||
# restore original Unet attention processors layers
|
||||
self.unet.set_default_attn_processor()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user