Compare commits
58 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6ca6fbd614 | |||
| 3e3d102f20 | |||
| 468ae09ed8 | |||
| 3fca52022f | |||
| c375903db5 | |||
| b9d52fca1d | |||
| 2ada094bff | |||
| 1b4c4d4614 | |||
| 28ef949cf6 | |||
| f1f542bdd4 | |||
| a9c403c001 | |||
| e7b9a0762b | |||
| 8eb17315c8 | |||
| c71c19c5e6 | |||
| adc31940a9 | |||
| 963ee05d16 | |||
| 668e34c6e0 | |||
| 25d7bb3ea6 | |||
| 394b8fb996 | |||
| a1d55e14ba | |||
| e5564d45bf | |||
| 2921a20194 | |||
| 3376252d71 | |||
| 16170c69ae | |||
| 4408047ac5 | |||
| 34fab8b511 | |||
| 298ce67999 | |||
| d2e7a19fd5 | |||
| cd3082008e | |||
| f3209b5b55 | |||
| 96399c3ec6 | |||
| 10d3220abe | |||
| f69511ecc6 | |||
| d2b10b1f4f | |||
| 23a2cd3337 | |||
| 4edde134f6 | |||
| 074a7cc3c5 | |||
| 6bfd13f07a | |||
| eeb70033a6 | |||
| c4a4750cb3 | |||
| a6375d4101 | |||
| 8e1b7a084a | |||
| 6946facf69 | |||
| 130dd936bb | |||
| a899e42fc7 | |||
| f96e4a16ad | |||
| 9c6e9684a2 | |||
| 2e4841ef1e | |||
| 8bea943714 | |||
| 614d0c64e9 | |||
| b1a2c0d577 | |||
| 06ee907b73 | |||
| 896fb6d8d7 | |||
| 7f51f286a5 | |||
| 829f6defa4 | |||
| 24bdf4b215 | |||
| 95e0c3757d | |||
| 6cf0be5d3d |
@@ -54,7 +54,7 @@ jobs:
|
|||||||
else
|
else
|
||||||
# e.g. refs/tags/v0.28.1 -> v0.28.1
|
# e.g. refs/tags/v0.28.1 -> v0.28.1
|
||||||
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
|
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
|
||||||
echo "PATH_IN_REPO=${${{ github.ref }}#refs/tags/}" >> $GITHUB_ENV
|
echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
|
||||||
fi
|
fi
|
||||||
- name: Print env vars
|
- name: Print env vars
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -33,4 +33,3 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||||
pytest tests/others/test_dependencies.py
|
pytest tests/others/test_dependencies.py
|
||||||
|
|
||||||
+3
-2
@@ -245,7 +245,7 @@ The official training examples are maintained by the Diffusers' core maintainers
|
|||||||
This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
|
This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
|
||||||
If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
|
If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
|
||||||
|
|
||||||
Both official training and research examples consist of a directory that contains one or more training scripts, a requirements.txt file, and a README.md file. In order for the user to make use of the
|
Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. In order for the user to make use of the
|
||||||
training examples, it is required to clone the repository:
|
training examples, it is required to clone the repository:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -255,7 +255,8 @@ git clone https://github.com/huggingface/diffusers
|
|||||||
as well as to install all additional dependencies required for training:
|
as well as to install all additional dependencies required for training:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -r /examples/<your-example-folder>/requirements.txt
|
cd diffusers
|
||||||
|
pip install -r examples/<your-example-folder>/requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
|
Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
|
||||||
|
|||||||
@@ -20,21 +20,11 @@ limitations under the License.
|
|||||||
<br>
|
<br>
|
||||||
<p>
|
<p>
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://github.com/huggingface/diffusers/blob/main/LICENSE">
|
<a href="https://github.com/huggingface/diffusers/blob/main/LICENSE"><img alt="GitHub" src="https://img.shields.io/github/license/huggingface/datasets.svg?color=blue"></a>
|
||||||
<img alt="GitHub" src="https://img.shields.io/github/license/huggingface/datasets.svg?color=blue">
|
<a href="https://github.com/huggingface/diffusers/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/diffusers.svg"></a>
|
||||||
</a>
|
<a href="https://pepy.tech/project/diffusers"><img alt="GitHub release" src="https://static.pepy.tech/badge/diffusers/month"></a>
|
||||||
<a href="https://github.com/huggingface/diffusers/releases">
|
<a href="CODE_OF_CONDUCT.md"><img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg"></a>
|
||||||
<img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/diffusers.svg">
|
<a href="https://twitter.com/diffuserslib"><img alt="X account" src="https://img.shields.io/twitter/url/https/twitter.com/diffuserslib.svg?style=social&label=Follow%20%40diffuserslib"></a>
|
||||||
</a>
|
|
||||||
<a href="https://pepy.tech/project/diffusers">
|
|
||||||
<img alt="GitHub release" src="https://static.pepy.tech/badge/diffusers/month">
|
|
||||||
</a>
|
|
||||||
<a href="CODE_OF_CONDUCT.md">
|
|
||||||
<img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg">
|
|
||||||
</a>
|
|
||||||
<a href="https://twitter.com/diffuserslib">
|
|
||||||
<img alt="X account" src="https://img.shields.io/twitter/url/https/twitter.com/diffuserslib.svg?style=social&label=Follow%20%40diffuserslib">
|
|
||||||
</a>
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or training your own diffusion models, 🤗 Diffusers is a modular toolbox that supports both. Our library is designed with a focus on [usability over performance](https://huggingface.co/docs/diffusers/conceptual/philosophy#usability-over-performance), [simple over easy](https://huggingface.co/docs/diffusers/conceptual/philosophy#simple-over-easy), and [customizability over abstractions](https://huggingface.co/docs/diffusers/conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
|
🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or training your own diffusion models, 🤗 Diffusers is a modular toolbox that supports both. Our library is designed with a focus on [usability over performance](https://huggingface.co/docs/diffusers/conceptual/philosophy#usability-over-performance), [simple over easy](https://huggingface.co/docs/diffusers/conceptual/philosophy#simple-over-easy), and [customizability over abstractions](https://huggingface.co/docs/diffusers/conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
|||||||
huggingface-hub \
|
huggingface-hub \
|
||||||
Jinja2 \
|
Jinja2 \
|
||||||
librosa \
|
librosa \
|
||||||
numpy \
|
numpy==1.26.4 \
|
||||||
scipy \
|
scipy \
|
||||||
tensorboard \
|
tensorboard \
|
||||||
transformers \
|
transformers \
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
|||||||
huggingface-hub \
|
huggingface-hub \
|
||||||
Jinja2 \
|
Jinja2 \
|
||||||
librosa \
|
librosa \
|
||||||
numpy \
|
numpy==1.26.4 \
|
||||||
scipy \
|
scipy \
|
||||||
tensorboard \
|
tensorboard \
|
||||||
transformers
|
transformers
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
|||||||
huggingface-hub \
|
huggingface-hub \
|
||||||
Jinja2 \
|
Jinja2 \
|
||||||
librosa \
|
librosa \
|
||||||
numpy \
|
numpy==1.26.4 \
|
||||||
scipy \
|
scipy \
|
||||||
tensorboard \
|
tensorboard \
|
||||||
transformers
|
transformers
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
|||||||
huggingface-hub \
|
huggingface-hub \
|
||||||
Jinja2 \
|
Jinja2 \
|
||||||
librosa \
|
librosa \
|
||||||
numpy \
|
numpy==1.26.4 \
|
||||||
scipy \
|
scipy \
|
||||||
tensorboard \
|
tensorboard \
|
||||||
transformers
|
transformers
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
|||||||
huggingface-hub \
|
huggingface-hub \
|
||||||
Jinja2 \
|
Jinja2 \
|
||||||
librosa \
|
librosa \
|
||||||
numpy \
|
numpy==1.26.4 \
|
||||||
scipy \
|
scipy \
|
||||||
tensorboard \
|
tensorboard \
|
||||||
transformers
|
transformers
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
|||||||
huggingface-hub \
|
huggingface-hub \
|
||||||
Jinja2 \
|
Jinja2 \
|
||||||
librosa \
|
librosa \
|
||||||
numpy \
|
numpy==1.26.4 \
|
||||||
scipy \
|
scipy \
|
||||||
tensorboard \
|
tensorboard \
|
||||||
transformers
|
transformers
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
|||||||
huggingface-hub \
|
huggingface-hub \
|
||||||
Jinja2 \
|
Jinja2 \
|
||||||
librosa \
|
librosa \
|
||||||
numpy \
|
numpy==1.26.4 \
|
||||||
scipy \
|
scipy \
|
||||||
tensorboard \
|
tensorboard \
|
||||||
transformers matplotlib
|
transformers matplotlib
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
|||||||
huggingface-hub \
|
huggingface-hub \
|
||||||
Jinja2 \
|
Jinja2 \
|
||||||
librosa \
|
librosa \
|
||||||
numpy \
|
numpy==1.26.4 \
|
||||||
scipy \
|
scipy \
|
||||||
tensorboard \
|
tensorboard \
|
||||||
transformers \
|
transformers \
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
|||||||
huggingface-hub \
|
huggingface-hub \
|
||||||
Jinja2 \
|
Jinja2 \
|
||||||
librosa \
|
librosa \
|
||||||
numpy \
|
numpy==1.26.4 \
|
||||||
scipy \
|
scipy \
|
||||||
tensorboard \
|
tensorboard \
|
||||||
transformers \
|
transformers \
|
||||||
|
|||||||
@@ -253,6 +253,8 @@
|
|||||||
title: PriorTransformer
|
title: PriorTransformer
|
||||||
- local: api/models/controlnet
|
- local: api/models/controlnet
|
||||||
title: ControlNetModel
|
title: ControlNetModel
|
||||||
|
- local: api/models/controlnet_sd3
|
||||||
|
title: SD3ControlNetModel
|
||||||
title: Models
|
title: Models
|
||||||
- isExpanded: false
|
- isExpanded: false
|
||||||
sections:
|
sections:
|
||||||
@@ -276,6 +278,8 @@
|
|||||||
title: Consistency Models
|
title: Consistency Models
|
||||||
- local: api/pipelines/controlnet
|
- local: api/pipelines/controlnet
|
||||||
title: ControlNet
|
title: ControlNet
|
||||||
|
- local: api/pipelines/controlnet_sd3
|
||||||
|
title: ControlNet with Stable Diffusion 3
|
||||||
- local: api/pipelines/controlnet_sdxl
|
- local: api/pipelines/controlnet_sdxl
|
||||||
title: ControlNet with Stable Diffusion XL
|
title: ControlNet with Stable Diffusion XL
|
||||||
- local: api/pipelines/controlnetxs
|
- local: api/pipelines/controlnetxs
|
||||||
|
|||||||
@@ -41,12 +41,6 @@ An attention processor is a class for applying different types of attention mech
|
|||||||
## FusedAttnProcessor2_0
|
## FusedAttnProcessor2_0
|
||||||
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
|
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
|
||||||
|
|
||||||
## LoRAAttnAddedKVProcessor
|
|
||||||
[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor
|
|
||||||
|
|
||||||
## LoRAXFormersAttnProcessor
|
|
||||||
[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor
|
|
||||||
|
|
||||||
## SlicedAttnProcessor
|
## SlicedAttnProcessor
|
||||||
[[autodoc]] models.attention_processor.SlicedAttnProcessor
|
[[autodoc]] models.attention_processor.SlicedAttnProcessor
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
|
|||||||
- [`StableDiffusionXLInstructPix2PixPipeline`]
|
- [`StableDiffusionXLInstructPix2PixPipeline`]
|
||||||
- [`StableDiffusionXLControlNetPipeline`]
|
- [`StableDiffusionXLControlNetPipeline`]
|
||||||
- [`StableDiffusionXLKDiffusionPipeline`]
|
- [`StableDiffusionXLKDiffusionPipeline`]
|
||||||
|
- [`StableDiffusion3Pipeline`]
|
||||||
- [`LatentConsistencyModelPipeline`]
|
- [`LatentConsistencyModelPipeline`]
|
||||||
- [`LatentConsistencyModelImg2ImgPipeline`]
|
- [`LatentConsistencyModelImg2ImgPipeline`]
|
||||||
- [`StableDiffusionControlNetXSPipeline`]
|
- [`StableDiffusionControlNetXSPipeline`]
|
||||||
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
|
|||||||
- [`StableCascadeUNet`]
|
- [`StableCascadeUNet`]
|
||||||
- [`AutoencoderKL`]
|
- [`AutoencoderKL`]
|
||||||
- [`ControlNetModel`]
|
- [`ControlNetModel`]
|
||||||
|
- [`SD3Transformer2DModel`]
|
||||||
|
|
||||||
## FromSingleFileMixin
|
## FromSingleFileMixin
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
<!--Copyright 2024 The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# SD3ControlNetModel
|
||||||
|
|
||||||
|
SD3ControlNetModel is an implementation of ControlNet for Stable Diffusion 3.
|
||||||
|
|
||||||
|
The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
|
||||||
|
|
||||||
|
The abstract from the paper is:
|
||||||
|
|
||||||
|
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
|
||||||
|
|
||||||
|
## Loading from the original format
|
||||||
|
|
||||||
|
By default, the [`SD3ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`].
|
||||||
|
|
||||||
|
```py
|
||||||
|
from diffusers import StableDiffusion3ControlNetPipeline
|
||||||
|
from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
|
||||||
|
|
||||||
|
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny")
|
||||||
|
pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet)
|
||||||
|
```
|
||||||
|
|
||||||
|
## SD3ControlNetModel
|
||||||
|
|
||||||
|
[[autodoc]] SD3ControlNetModel
|
||||||
|
|
||||||
|
## SD3ControlNetOutput
|
||||||
|
|
||||||
|
[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
|
||||||
|
|
||||||
@@ -78,7 +78,6 @@ output = pipe(
|
|||||||
)
|
)
|
||||||
frames = output.frames[0]
|
frames = output.frames[0]
|
||||||
export_to_gif(frames, "animation.gif")
|
export_to_gif(frames, "animation.gif")
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Here are some sample outputs:
|
Here are some sample outputs:
|
||||||
@@ -303,7 +302,6 @@ output = pipe(
|
|||||||
)
|
)
|
||||||
frames = output.frames[0]
|
frames = output.frames[0]
|
||||||
export_to_gif(frames, "animation.gif")
|
export_to_gif(frames, "animation.gif")
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
<table>
|
<table>
|
||||||
@@ -378,7 +376,6 @@ output = pipe(
|
|||||||
)
|
)
|
||||||
frames = output.frames[0]
|
frames = output.frames[0]
|
||||||
export_to_gif(frames, "animation.gif")
|
export_to_gif(frames, "animation.gif")
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
<table>
|
<table>
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
<!--Copyright 2023 The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# ControlNet with Stable Diffusion 3
|
||||||
|
|
||||||
|
StableDiffusion3ControlNetPipeline is an implementation of ControlNet for Stable Diffusion 3.
|
||||||
|
|
||||||
|
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
|
||||||
|
|
||||||
|
With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
|
||||||
|
|
||||||
|
The abstract from the paper is:
|
||||||
|
|
||||||
|
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
|
||||||
|
|
||||||
|
This code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for SD3-ControlNet on the [InstantX Team](https://huggingface.co/InstantX) Hub profile.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
## StableDiffusion3ControlNetPipeline
|
||||||
|
[[autodoc]] StableDiffusion3ControlNetPipeline
|
||||||
|
- all
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
## StableDiffusion3PipelineOutput
|
||||||
|
[[autodoc]] pipelines.stable_diffusion_3.pipeline_output.StableDiffusion3PipelineOutput
|
||||||
@@ -186,7 +186,7 @@ pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgrap
|
|||||||
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
|
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
|
||||||
|
|
||||||
# Warm Up
|
# Warm Up
|
||||||
prompt = "a photo of a cat holding a sign that says hello world",
|
prompt = "a photo of a cat holding a sign that says hello world"
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
_ = pipe(prompt=prompt, generator=torch.manual_seed(1))
|
_ = pipe(prompt=prompt, generator=torch.manual_seed(1))
|
||||||
|
|
||||||
@@ -197,6 +197,27 @@ image.save("sd3_hello_world.png")
|
|||||||
|
|
||||||
Check out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97).
|
Check out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97).
|
||||||
|
|
||||||
|
## Tiny AutoEncoder for Stable Diffusion 3
|
||||||
|
|
||||||
|
Tiny AutoEncoder for Stable Diffusion (TAESD3) is a tiny distilled version of Stable Diffusion 3's VAE by [Ollin Boer Bohan](https://github.com/madebyollin/taesd) that can decode [`StableDiffusion3Pipeline`] latents almost instantly.
|
||||||
|
|
||||||
|
To use with Stable Diffusion 3:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffusers import StableDiffusion3Pipeline, AutoencoderTiny
|
||||||
|
|
||||||
|
pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||||
|
"stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
|
||||||
|
)
|
||||||
|
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd3", torch_dtype=torch.float16)
|
||||||
|
pipe = pipe.to("cuda")
|
||||||
|
|
||||||
|
prompt = "slice of delicious New York-style berry cheesecake"
|
||||||
|
image = pipe(prompt, num_inference_steps=25).images[0]
|
||||||
|
image.save("cheesecake.png")
|
||||||
|
```
|
||||||
|
|
||||||
## Loading the original checkpoints via `from_single_file`
|
## Loading the original checkpoints via `from_single_file`
|
||||||
|
|
||||||
The `SD3Transformer2DModel` and `StableDiffusion3Pipeline` classes support loading the original checkpoints via the `from_single_file` method. This method allows you to load the original checkpoint files that were used to train the models.
|
The `SD3Transformer2DModel` and `StableDiffusion3Pipeline` classes support loading the original checkpoints via the `from_single_file` method. This method allows you to load the original checkpoint files that were used to train the models.
|
||||||
@@ -211,17 +232,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability
|
|||||||
|
|
||||||
## Loading the single checkpoint for the `StableDiffusion3Pipeline`
|
## Loading the single checkpoint for the `StableDiffusion3Pipeline`
|
||||||
|
|
||||||
```python
|
### Loading the single file checkpoint without T5
|
||||||
from diffusers import StableDiffusion3Pipeline
|
|
||||||
from transformers import T5EncoderModel
|
|
||||||
|
|
||||||
text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
|
```python
|
||||||
pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3)
|
import torch
|
||||||
|
from diffusers import StableDiffusion3Pipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusion3Pipeline.from_single_file(
|
||||||
|
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
text_encoder_3=None
|
||||||
|
)
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
|
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
|
||||||
|
image.save('sd3-single-file.png')
|
||||||
```
|
```
|
||||||
|
|
||||||
<Tip>
|
### Loading the single file checkpoint with T5
|
||||||
`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
|
|
||||||
</Tip>
|
```python
|
||||||
|
import torch
|
||||||
|
from diffusers import StableDiffusion3Pipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusion3Pipeline.from_single_file(
|
||||||
|
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
|
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
|
||||||
|
image.save('sd3-single-file-t5-fp8.png')
|
||||||
|
```
|
||||||
|
|
||||||
## StableDiffusion3Pipeline
|
## StableDiffusion3Pipeline
|
||||||
|
|
||||||
|
|||||||
@@ -22,14 +22,13 @@ We enormously value feedback from the community, so please do not be afraid to s
|
|||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
You can contribute in many ways ranging from answering questions on issues to adding new diffusion models to
|
You can contribute in many ways ranging from answering questions on issues and discussions to adding new diffusion models to the core library.
|
||||||
the core library.
|
|
||||||
|
|
||||||
In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community.
|
In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community.
|
||||||
|
|
||||||
* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR).
|
* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR).
|
||||||
* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose).
|
* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose) or new discussions on [the GitHub Discussions tab](https://github.com/huggingface/diffusers/discussions/new/choose).
|
||||||
* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues).
|
* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues) or discussions on [the GitHub Discussions tab](https://github.com/huggingface/diffusers/discussions).
|
||||||
* 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
|
* 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
|
||||||
* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source).
|
* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source).
|
||||||
* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples).
|
* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples).
|
||||||
@@ -63,7 +62,7 @@ In the same spirit, you are of immense help to the community by answering such q
|
|||||||
|
|
||||||
**Please** keep in mind that the more effort you put into asking or answering a question, the higher
|
**Please** keep in mind that the more effort you put into asking or answering a question, the higher
|
||||||
the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
|
the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
|
||||||
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
|
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
|
||||||
|
|
||||||
**NOTE about channels**:
|
**NOTE about channels**:
|
||||||
[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago.
|
[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago.
|
||||||
@@ -99,7 +98,7 @@ This means in more detail:
|
|||||||
- Format your code.
|
- Format your code.
|
||||||
- Do not include any external libraries except for Diffusers depending on them.
|
- Do not include any external libraries except for Diffusers depending on them.
|
||||||
- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue.
|
- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue.
|
||||||
- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, she cannot solve it.
|
- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, (s)he cannot solve it.
|
||||||
- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell.
|
- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell.
|
||||||
- If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible.
|
- If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible.
|
||||||
|
|
||||||
@@ -288,7 +287,7 @@ The official training examples are maintained by the Diffusers' core maintainers
|
|||||||
This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
|
This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
|
||||||
If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
|
If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
|
||||||
|
|
||||||
Both official training and research examples consist of a directory that contains one or more training scripts, a requirements.txt file, and a README.md file. In order for the user to make use of the
|
Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. In order for the user to make use of the
|
||||||
training examples, it is required to clone the repository:
|
training examples, it is required to clone the repository:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -298,7 +297,8 @@ git clone https://github.com/huggingface/diffusers
|
|||||||
as well as to install all additional dependencies required for training:
|
as well as to install all additional dependencies required for training:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -r /examples/<your-example-folder>/requirements.txt
|
cd diffusers
|
||||||
|
pip install -r examples/<your-example-folder>/requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
|
Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
|
||||||
@@ -316,7 +316,7 @@ Once an example script works, please make sure to add a comprehensive `README.md
|
|||||||
- A link to some training results (logs, models, etc.) that show what the user can expect as shown [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
|
- A link to some training results (logs, models, etc.) that show what the user can expect as shown [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
|
||||||
- If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations).
|
- If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations).
|
||||||
|
|
||||||
If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples.
|
If you are contributing to the official training examples, please also make sure to add a test to the example's folder, such as [examples/dreambooth/test_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/test_dreambooth.py). This is not necessary for non-official training examples.
|
||||||
|
|
||||||
### 8. Fixing a "Good second issue"
|
### 8. Fixing a "Good second issue"
|
||||||
|
|
||||||
@@ -418,7 +418,7 @@ You will need basic `git` proficiency to be able to contribute to
|
|||||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||||
|
|
||||||
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L244)):
|
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/83bc6c94eaeb6f7704a2a428931cf2d9ad973ae9/setup.py#L270)):
|
||||||
|
|
||||||
1. Fork the [repository](https://github.com/huggingface/diffusers) by
|
1. Fork the [repository](https://github.com/huggingface/diffusers) by
|
||||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||||
|
|||||||
@@ -349,7 +349,7 @@ control_image = load_image("./conditioning_image_1.png")
|
|||||||
prompt = "pale golden rod circle with old lace background"
|
prompt = "pale golden rod circle with old lace background"
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = pipe(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]
|
image = pipeline(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]
|
||||||
image.save("./output.png")
|
image.save("./output.png")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
|
|||||||
--max_train_steps=15000 \
|
--max_train_steps=15000 \
|
||||||
--learning_rate=1e-05 \
|
--learning_rate=1e-05 \
|
||||||
--max_grad_norm=1 \
|
--max_grad_norm=1 \
|
||||||
--enable_xformers_memory_efficient_attention
|
--enable_xformers_memory_efficient_attention \
|
||||||
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||||
--output_dir="sd-naruto-model" \
|
--output_dir="sd-naruto-model" \
|
||||||
--push_to_hub
|
--push_to_hub
|
||||||
|
|||||||
@@ -34,13 +34,10 @@ Install [PyTorch nightly](https://pytorch.org/) to benefit from the latest and f
|
|||||||
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
|
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
|
||||||
```
|
```
|
||||||
|
|
||||||
<Tip>
|
> [!TIP]
|
||||||
|
> The results reported below are from an 80GB 400W A100 with its clock rate set to the maximum.
|
||||||
|
> If you're interested in the full benchmarking code, take a look at [huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast).
|
||||||
|
|
||||||
The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum. <br>
|
|
||||||
|
|
||||||
If you're interested in the full benchmarking code, take a look at [huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast).
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
## Baseline
|
## Baseline
|
||||||
|
|
||||||
@@ -170,6 +167,9 @@ Using SDPA attention and compiling both the UNet and VAE cuts the latency from 3
|
|||||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_3.png" width=500>
|
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_3.png" width=500>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> From PyTorch 2.3.1, you can control the caching behavior of `torch.compile()`. This is particularly beneficial for compilation modes like `"max-autotune"`, which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.
|
||||||
|
|
||||||
### Prevent graph breaks
|
### Prevent graph breaks
|
||||||
|
|
||||||
Specifying `fullgraph=True` ensures there are no graph breaks in the underlying model to take full advantage of `torch.compile` without any performance degradation. For the UNet and VAE, this means changing how you access the return variables.
|
Specifying `fullgraph=True` ensures there are no graph breaks in the underlying model to take full advantage of `torch.compile` without any performance degradation. For the UNet and VAE, this means changing how you access the return variables.
|
||||||
|
|||||||
@@ -472,7 +472,6 @@ my_local_config_path = snapshot_download(
|
|||||||
local_dir_use_symlinks=False,
|
local_dir_use_symlinks=False,
|
||||||
)
|
)
|
||||||
print("My local config: ", my_local_config_path)
|
print("My local config: ", my_local_config_path)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Then you can pass the local paths to the `pretrained_model_link_or_path` and `config` parameters.
|
Then you can pass the local paths to the `pretrained_model_link_or_path` and `config` parameters.
|
||||||
|
|||||||
@@ -436,7 +436,7 @@ lora_path = "lora-library/B-LoRA-pen_sketch"
|
|||||||
|
|
||||||
state_dict = lora_lora_unet_blocks(content_B_lora_path,alpha=1,target_blocks=["unet.up_blocks.0.attentions.0"])
|
state_dict = lora_lora_unet_blocks(content_B_lora_path,alpha=1,target_blocks=["unet.up_blocks.0.attentions.0"])
|
||||||
|
|
||||||
# Load traine dlora layers into the unet
|
# Load trained lora layers into the unet
|
||||||
pipeline.load_lora_into_unet(state_dict, None, pipeline.unet)
|
pipeline.load_lora_into_unet(state_dict, None, pipeline.unet)
|
||||||
|
|
||||||
#generate
|
#generate
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -326,7 +326,7 @@ def parse_args(input_args=None):
|
|||||||
type=str,
|
type=str,
|
||||||
default="TOK",
|
default="TOK",
|
||||||
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
|
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
|
||||||
"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. "
|
"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. "
|
||||||
"'TOK,TOK2,TOK3' etc.",
|
"'TOK,TOK2,TOK3' etc.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -559,7 +559,7 @@ def parse_args(input_args=None):
|
|||||||
"--prodigy_beta3",
|
"--prodigy_beta3",
|
||||||
type=float,
|
type=float,
|
||||||
default=None,
|
default=None,
|
||||||
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
|
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
|
||||||
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
||||||
)
|
)
|
||||||
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
||||||
@@ -736,7 +736,7 @@ class TokenEmbeddingsHandler:
|
|||||||
# random initialization of new tokens
|
# random initialization of new tokens
|
||||||
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
|
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
|
||||||
|
|
||||||
print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}")
|
print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
|
||||||
|
|
||||||
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
|
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
|
||||||
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
|
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
|
||||||
@@ -948,7 +948,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
else: # costum prompts were provided, but length does not match size of image dataset
|
else: # custom prompts were provided, but length does not match size of image dataset
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
if self.class_data_root:
|
if self.class_data_root:
|
||||||
@@ -1967,7 +1967,7 @@ def main(args):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Conver to WebUI format
|
# Convert to WebUI format
|
||||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -348,7 +348,7 @@ def parse_args(input_args=None):
|
|||||||
type=str,
|
type=str,
|
||||||
default="TOK",
|
default="TOK",
|
||||||
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
|
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
|
||||||
"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. "
|
"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. "
|
||||||
"'TOK,TOK2,TOK3' etc.",
|
"'TOK,TOK2,TOK3' etc.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -591,7 +591,7 @@ def parse_args(input_args=None):
|
|||||||
"--prodigy_beta3",
|
"--prodigy_beta3",
|
||||||
type=float,
|
type=float,
|
||||||
default=None,
|
default=None,
|
||||||
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
|
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
|
||||||
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
||||||
)
|
)
|
||||||
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
||||||
@@ -824,7 +824,7 @@ class TokenEmbeddingsHandler:
|
|||||||
# random initialization of new tokens
|
# random initialization of new tokens
|
||||||
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
|
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
|
||||||
|
|
||||||
print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}")
|
print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
|
||||||
|
|
||||||
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
|
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
|
||||||
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
|
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
|
||||||
@@ -1097,7 +1097,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
else: # costum prompts were provided, but length does not match size of image dataset
|
else: # custom prompts were provided, but length does not match size of image dataset
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
if self.class_data_root:
|
if self.class_data_root:
|
||||||
@@ -1794,7 +1794,7 @@ def main(args):
|
|||||||
if args.with_prior_preservation:
|
if args.with_prior_preservation:
|
||||||
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
||||||
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
|
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
|
||||||
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
# if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
||||||
# batch prompts on all training steps
|
# batch prompts on all training steps
|
||||||
else:
|
else:
|
||||||
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens)
|
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens)
|
||||||
@@ -2411,7 +2411,7 @@ def main(args):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Conver to WebUI format
|
# Convert to WebUI format
|
||||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||||
|
|||||||
@@ -3595,7 +3595,7 @@ This pipeline provides drag-and-drop image editing using stochastic differential
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
See [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more infomation.
|
See [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more information.
|
||||||
|
|
||||||
```py
|
```py
|
||||||
import PIL
|
import PIL
|
||||||
|
|||||||
@@ -24,12 +24,7 @@ from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
|||||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||||
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||||
AttnProcessor2_0,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
XFormersAttnProcessor,
|
|
||||||
)
|
|
||||||
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
|
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
|
||||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||||
@@ -1292,12 +1287,7 @@ class SDXLLongPromptWeightingPipeline(
|
|||||||
self.vae.to(dtype=torch.float32)
|
self.vae.to(dtype=torch.float32)
|
||||||
use_torch_2_0_or_xformers = isinstance(
|
use_torch_2_0_or_xformers = isinstance(
|
||||||
self.vae.decoder.mid_block.attentions[0].processor,
|
self.vae.decoder.mid_block.attentions[0].processor,
|
||||||
(
|
(AttnProcessor2_0, XFormersAttnProcessor),
|
||||||
AttnProcessor2_0,
|
|
||||||
XFormersAttnProcessor,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
# if xformers or torch_2_0 is used attention block does not need
|
# if xformers or torch_2_0 is used attention block does not need
|
||||||
# to be in float32 which can save lots of memory
|
# to be in float32 which can save lots of memory
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
|
|
||||||
class MarigoldDepthOutput(BaseOutput):
|
class MarigoldDepthOutput(BaseOutput):
|
||||||
|
|||||||
@@ -16,12 +16,7 @@ from diffusers.loaders import (
|
|||||||
TextualInversionLoaderMixin,
|
TextualInversionLoaderMixin,
|
||||||
)
|
)
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||||
AttnProcessor2_0,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
XFormersAttnProcessor,
|
|
||||||
)
|
|
||||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||||
@@ -612,12 +607,7 @@ class DemoFusionSDXLPipeline(
|
|||||||
self.vae.to(dtype=torch.float32)
|
self.vae.to(dtype=torch.float32)
|
||||||
use_torch_2_0_or_xformers = isinstance(
|
use_torch_2_0_or_xformers = isinstance(
|
||||||
self.vae.decoder.mid_block.attentions[0].processor,
|
self.vae.decoder.mid_block.attentions[0].processor,
|
||||||
(
|
(AttnProcessor2_0, XFormersAttnProcessor),
|
||||||
AttnProcessor2_0,
|
|
||||||
XFormersAttnProcessor,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
# if xformers or torch_2_0 is used attention block does not need
|
# if xformers or torch_2_0 is used attention block does not need
|
||||||
# to be in float32 which can save lots of memory
|
# to be in float32 which can save lots of memory
|
||||||
@@ -805,10 +795,10 @@ class DemoFusionSDXLPipeline(
|
|||||||
Control the strength of dilated sampling. For specific impacts, please refer to Appendix C
|
Control the strength of dilated sampling. For specific impacts, please refer to Appendix C
|
||||||
in the DemoFusion paper.
|
in the DemoFusion paper.
|
||||||
cosine_scale_3 (`float`, defaults to 1):
|
cosine_scale_3 (`float`, defaults to 1):
|
||||||
Control the strength of the gaussion filter. For specific impacts, please refer to Appendix C
|
Control the strength of the gaussian filter. For specific impacts, please refer to Appendix C
|
||||||
in the DemoFusion paper.
|
in the DemoFusion paper.
|
||||||
sigma (`float`, defaults to 1):
|
sigma (`float`, defaults to 1):
|
||||||
The standerd value of the gaussian filter.
|
The standard value of the gaussian filter.
|
||||||
show_image (`bool`, defaults to False):
|
show_image (`bool`, defaults to False):
|
||||||
Determine whether to show intermediate results during generation.
|
Determine whether to show intermediate results during generation.
|
||||||
|
|
||||||
|
|||||||
@@ -46,8 +46,6 @@ from diffusers.models.attention_processor import (
|
|||||||
Attention,
|
Attention,
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
FusedAttnProcessor2_0,
|
FusedAttnProcessor2_0,
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
XFormersAttnProcessor,
|
XFormersAttnProcessor,
|
||||||
)
|
)
|
||||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||||
@@ -1153,8 +1151,6 @@ class StyleAlignedSDXLPipeline(
|
|||||||
(
|
(
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
XFormersAttnProcessor,
|
XFormersAttnProcessor,
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
FusedAttnProcessor2_0,
|
FusedAttnProcessor2_0,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,12 +25,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
|
|||||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||||
from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||||
from diffusers.models import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter, UNet2DConditionModel
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||||
AttnProcessor2_0,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
XFormersAttnProcessor,
|
|
||||||
)
|
|
||||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||||
@@ -797,12 +792,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
|
|||||||
self.vae.to(dtype=torch.float32)
|
self.vae.to(dtype=torch.float32)
|
||||||
use_torch_2_0_or_xformers = isinstance(
|
use_torch_2_0_or_xformers = isinstance(
|
||||||
self.vae.decoder.mid_block.attentions[0].processor,
|
self.vae.decoder.mid_block.attentions[0].processor,
|
||||||
(
|
(AttnProcessor2_0, XFormersAttnProcessor),
|
||||||
AttnProcessor2_0,
|
|
||||||
XFormersAttnProcessor,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
# if xformers or torch_2_0 is used attention block does not need
|
# if xformers or torch_2_0 is used attention block does not need
|
||||||
# to be in float32 which can save lots of memory
|
# to be in float32 which can save lots of memory
|
||||||
|
|||||||
@@ -44,12 +44,7 @@ from diffusers.models import (
|
|||||||
T2IAdapter,
|
T2IAdapter,
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
)
|
)
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||||
AttnProcessor2_0,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
XFormersAttnProcessor,
|
|
||||||
)
|
|
||||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||||
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
|
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
|
||||||
@@ -1135,12 +1130,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
|||||||
self.vae.to(dtype=torch.float32)
|
self.vae.to(dtype=torch.float32)
|
||||||
use_torch_2_0_or_xformers = isinstance(
|
use_torch_2_0_or_xformers = isinstance(
|
||||||
self.vae.decoder.mid_block.attentions[0].processor,
|
self.vae.decoder.mid_block.attentions[0].processor,
|
||||||
(
|
(AttnProcessor2_0, XFormersAttnProcessor),
|
||||||
AttnProcessor2_0,
|
|
||||||
XFormersAttnProcessor,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
# if xformers or torch_2_0 is used attention block does not need
|
# if xformers or torch_2_0 is used attention block does not need
|
||||||
# to be in float32 which can save lots of memory
|
# to be in float32 which can save lots of memory
|
||||||
|
|||||||
@@ -37,8 +37,6 @@ from diffusers.loaders import (
|
|||||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.attention_processor import (
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
XFormersAttnProcessor,
|
XFormersAttnProcessor,
|
||||||
)
|
)
|
||||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||||
@@ -854,8 +852,6 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
|
|||||||
(
|
(
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
XFormersAttnProcessor,
|
XFormersAttnProcessor,
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
# if xformers or torch_2_0 is used attention block does not need
|
# if xformers or torch_2_0 is used attention block does not need
|
||||||
|
|||||||
@@ -34,8 +34,6 @@ from diffusers.loaders import (
|
|||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.attention_processor import (
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
XFormersAttnProcessor,
|
XFormersAttnProcessor,
|
||||||
)
|
)
|
||||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||||
@@ -662,8 +660,6 @@ class StableDiffusionXLPipelineIpex(
|
|||||||
(
|
(
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
XFormersAttnProcessor,
|
XFormersAttnProcessor,
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
# if xformers or torch_2_0 is used attention block does not need
|
# if xformers or torch_2_0 is used attention block does not need
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -1111,11 +1111,16 @@ def main(args):
|
|||||||
|
|
||||||
# 15. LR Scheduler creation
|
# 15. LR Scheduler creation
|
||||||
# Scheduler and math around the number of training steps.
|
# Scheduler and math around the number of training steps.
|
||||||
overrode_max_train_steps = False
|
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||||
if args.max_train_steps is None:
|
if args.max_train_steps is None:
|
||||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||||
overrode_max_train_steps = True
|
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||||
|
num_training_steps_for_scheduler = (
|
||||||
|
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||||
|
|
||||||
if args.scale_lr:
|
if args.scale_lr:
|
||||||
args.learning_rate = (
|
args.learning_rate = (
|
||||||
@@ -1130,8 +1135,8 @@ def main(args):
|
|||||||
lr_scheduler = get_scheduler(
|
lr_scheduler = get_scheduler(
|
||||||
args.lr_scheduler,
|
args.lr_scheduler,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
num_training_steps=num_training_steps_for_scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 16. Prepare for training
|
# 16. Prepare for training
|
||||||
@@ -1142,8 +1147,14 @@ def main(args):
|
|||||||
|
|
||||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
if overrode_max_train_steps:
|
if args.max_train_steps is None:
|
||||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||||
|
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||||
|
logger.warning(
|
||||||
|
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||||
|
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||||
|
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||||
|
)
|
||||||
# Afterwards we recalculate our number of training epochs
|
# Afterwards we recalculate our number of training epochs
|
||||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ The `train_dreambooth_sd3.py` script shows how to implement the training procedu
|
|||||||
huggingface-cli login
|
huggingface-cli login
|
||||||
```
|
```
|
||||||
|
|
||||||
|
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||||
|
|
||||||
## Running locally with PyTorch
|
## Running locally with PyTorch
|
||||||
|
|
||||||
### Installing the dependencies
|
### Installing the dependencies
|
||||||
@@ -106,6 +108,9 @@ To better track our training experiments, we're using the following flags in the
|
|||||||
* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
|
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
|
||||||
|
|
||||||
@@ -113,6 +118,8 @@ To better track our training experiments, we're using the following flags in the
|
|||||||
|
|
||||||
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
||||||
|
|
||||||
|
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
|
||||||
|
|
||||||
To perform DreamBooth with LoRA, run:
|
To perform DreamBooth with LoRA, run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -139,3 +146,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
|
|||||||
--seed="0" \
|
--seed="0" \
|
||||||
--push_to_hub
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Other notes
|
||||||
|
|
||||||
|
We default to the "logit_normal" weighting scheme for the loss, following the SD3 paper. Thanks to @bghira for helping us discover that training may incur numerical instabilities with the other weighting schemes supported by the training script.
|
||||||
@@ -4,4 +4,5 @@ transformers>=4.41.2
|
|||||||
ftfy
|
ftfy
|
||||||
tensorboard
|
tensorboard
|
||||||
Jinja2
|
Jinja2
|
||||||
peft== 0.11.1
|
peft==0.11.1
|
||||||
|
sentencepiece
|
||||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
# Cache compiled models across invocations of this script.
|
# Cache compiled models across invocations of this script.
|
||||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,11 @@ from diffusers import (
|
|||||||
StableDiffusion3Pipeline,
|
StableDiffusion3Pipeline,
|
||||||
)
|
)
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.training_utils import cast_training_params
|
from diffusers.training_utils import (
|
||||||
|
cast_training_params,
|
||||||
|
compute_density_for_timestep_sampling,
|
||||||
|
compute_loss_weighting_for_sd3,
|
||||||
|
)
|
||||||
from diffusers.utils import (
|
from diffusers.utils import (
|
||||||
check_min_version,
|
check_min_version,
|
||||||
convert_unet_state_dict_to_peft,
|
convert_unet_state_dict_to_peft,
|
||||||
@@ -67,7 +71,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.28.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -298,6 +302,12 @@ def parse_args(input_args=None):
|
|||||||
default=None,
|
default=None,
|
||||||
help="The prompt to specify images in the same class as provided instance images.",
|
help="The prompt to specify images in the same class as provided instance images.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_sequence_length",
|
||||||
|
type=int,
|
||||||
|
default=77,
|
||||||
|
help="Maximum sequence length to use with with the T5 text encoder",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validation_prompt",
|
"--validation_prompt",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -467,11 +477,23 @@ def parse_args(input_args=None):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"]
|
"--weighting_scheme",
|
||||||
|
type=str,
|
||||||
|
default="logit_normal",
|
||||||
|
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode_scale",
|
||||||
|
type=float,
|
||||||
|
default=1.29,
|
||||||
|
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--logit_mean", type=float, default=0.0)
|
|
||||||
parser.add_argument("--logit_std", type=float, default=1.0)
|
|
||||||
parser.add_argument("--mode_scale", type=float, default=1.29)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--optimizer",
|
"--optimizer",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -495,7 +517,7 @@ def parse_args(input_args=None):
|
|||||||
"--prodigy_beta3",
|
"--prodigy_beta3",
|
||||||
type=float,
|
type=float,
|
||||||
default=None,
|
default=None,
|
||||||
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
|
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
|
||||||
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
||||||
)
|
)
|
||||||
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
||||||
@@ -766,7 +788,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
else: # costum prompts were provided, but length does not match size of image dataset
|
else: # custom prompts were provided, but length does not match size of image dataset
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
if self.class_data_root:
|
if self.class_data_root:
|
||||||
@@ -830,6 +852,7 @@ def tokenize_prompt(tokenizer, prompt):
|
|||||||
def _encode_prompt_with_t5(
|
def _encode_prompt_with_t5(
|
||||||
text_encoder,
|
text_encoder,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
max_sequence_length,
|
||||||
prompt=None,
|
prompt=None,
|
||||||
num_images_per_prompt=1,
|
num_images_per_prompt=1,
|
||||||
device=None,
|
device=None,
|
||||||
@@ -840,7 +863,7 @@ def _encode_prompt_with_t5(
|
|||||||
text_inputs = tokenizer(
|
text_inputs = tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
max_length=77,
|
max_length=max_sequence_length,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
@@ -897,6 +920,7 @@ def encode_prompt(
|
|||||||
text_encoders,
|
text_encoders,
|
||||||
tokenizers,
|
tokenizers,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
max_sequence_length,
|
||||||
device=None,
|
device=None,
|
||||||
num_images_per_prompt: int = 1,
|
num_images_per_prompt: int = 1,
|
||||||
):
|
):
|
||||||
@@ -924,6 +948,7 @@ def encode_prompt(
|
|||||||
t5_prompt_embed = _encode_prompt_with_t5(
|
t5_prompt_embed = _encode_prompt_with_t5(
|
||||||
text_encoders[-1],
|
text_encoders[-1],
|
||||||
tokenizers[-1],
|
tokenizers[-1],
|
||||||
|
max_sequence_length,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_images_per_prompt=num_images_per_prompt,
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
device=device if device is not None else text_encoders[-1].device,
|
device=device if device is not None else text_encoders[-1].device,
|
||||||
@@ -1297,7 +1322,9 @@ def main(args):
|
|||||||
|
|
||||||
def compute_text_embeddings(prompt, text_encoders, tokenizers):
|
def compute_text_embeddings(prompt, text_encoders, tokenizers):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
|
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||||
|
text_encoders, tokenizers, prompt, args.max_sequence_length
|
||||||
|
)
|
||||||
prompt_embeds = prompt_embeds.to(accelerator.device)
|
prompt_embeds = prompt_embeds.to(accelerator.device)
|
||||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||||
return prompt_embeds, pooled_prompt_embeds
|
return prompt_embeds, pooled_prompt_embeds
|
||||||
@@ -1316,6 +1343,8 @@ def main(args):
|
|||||||
# Clear the memory here
|
# Clear the memory here
|
||||||
if not train_dataset.custom_instance_prompts:
|
if not train_dataset.custom_instance_prompts:
|
||||||
del tokenizers, text_encoders
|
del tokenizers, text_encoders
|
||||||
|
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||||
|
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -1330,7 +1359,7 @@ def main(args):
|
|||||||
if args.with_prior_preservation:
|
if args.with_prior_preservation:
|
||||||
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
||||||
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
|
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
|
||||||
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
# if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
||||||
# batch prompts on all training steps
|
# batch prompts on all training steps
|
||||||
else:
|
else:
|
||||||
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
||||||
@@ -1462,7 +1491,15 @@ def main(args):
|
|||||||
bsz = model_input.shape[0]
|
bsz = model_input.shape[0]
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
|
# for weighting schemes where we sample timesteps non-uniformly
|
||||||
|
u = compute_density_for_timestep_sampling(
|
||||||
|
weighting_scheme=args.weighting_scheme,
|
||||||
|
batch_size=bsz,
|
||||||
|
logit_mean=args.logit_mean,
|
||||||
|
logit_std=args.logit_std,
|
||||||
|
mode_scale=args.mode_scale,
|
||||||
|
)
|
||||||
|
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
||||||
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
||||||
|
|
||||||
# Add noise according to flow matching.
|
# Add noise according to flow matching.
|
||||||
@@ -1482,20 +1519,11 @@ def main(args):
|
|||||||
# Preconditioning of the model outputs.
|
# Preconditioning of the model outputs.
|
||||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||||
|
|
||||||
# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
|
# these weighting schemes use a uniform timestep sampling
|
||||||
if args.weighting_scheme == "sigma_sqrt":
|
# and instead post-weight the loss
|
||||||
weighting = (sigmas**-2.0).float()
|
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
||||||
elif args.weighting_scheme == "logit_normal":
|
|
||||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
|
||||||
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
|
|
||||||
weighting = torch.nn.functional.sigmoid(u)
|
|
||||||
elif args.weighting_scheme == "mode":
|
|
||||||
# See sec 3.1 in the SD3 paper (20).
|
|
||||||
u = torch.rand(size=(bsz,), device=accelerator.device)
|
|
||||||
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
|
||||||
|
|
||||||
# simplified flow matching aka 0-rectified flow matching loss
|
# flow matching loss
|
||||||
# target = model_input - noise
|
|
||||||
target = model_input
|
target = model_input
|
||||||
|
|
||||||
if args.with_prior_preservation:
|
if args.with_prior_preservation:
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -562,7 +562,7 @@ def parse_args(input_args=None):
|
|||||||
"--prodigy_beta3",
|
"--prodigy_beta3",
|
||||||
type=float,
|
type=float,
|
||||||
default=None,
|
default=None,
|
||||||
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
|
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
|
||||||
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
||||||
)
|
)
|
||||||
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
||||||
@@ -861,7 +861,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
else: # costum prompts were provided, but length does not match size of image dataset
|
else: # custom prompts were provided, but length does not match size of image dataset
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
if self.class_data_root:
|
if self.class_data_root:
|
||||||
@@ -1289,8 +1289,8 @@ def main(args):
|
|||||||
models = [unet_]
|
models = [unet_]
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
models.extend([text_encoder_one_, text_encoder_two_])
|
models.extend([text_encoder_one_, text_encoder_two_])
|
||||||
# only upcast trainable parameters (LoRA) into fp32
|
# only upcast trainable parameters (LoRA) into fp32
|
||||||
cast_training_params(models)
|
cast_training_params(models)
|
||||||
|
|
||||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||||
@@ -1488,7 +1488,7 @@ def main(args):
|
|||||||
if args.with_prior_preservation:
|
if args.with_prior_preservation:
|
||||||
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
||||||
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
|
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
|
||||||
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
# if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
||||||
# batch prompts on all training steps
|
# batch prompts on all training steps
|
||||||
else:
|
else:
|
||||||
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ from diffusers import (
|
|||||||
StableDiffusion3Pipeline,
|
StableDiffusion3Pipeline,
|
||||||
)
|
)
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
|
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
|
||||||
from diffusers.utils import (
|
from diffusers.utils import (
|
||||||
check_min_version,
|
check_min_version,
|
||||||
is_wandb_available,
|
is_wandb_available,
|
||||||
@@ -63,7 +64,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.28.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -297,6 +298,12 @@ def parse_args(input_args=None):
|
|||||||
default=None,
|
default=None,
|
||||||
help="The prompt to specify images in the same class as provided instance images.",
|
help="The prompt to specify images in the same class as provided instance images.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_sequence_length",
|
||||||
|
type=int,
|
||||||
|
default=77,
|
||||||
|
help="Maximum sequence length to use with with the T5 text encoder",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validation_prompt",
|
"--validation_prompt",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -465,11 +472,23 @@ def parse_args(input_args=None):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"]
|
"--weighting_scheme",
|
||||||
|
type=str,
|
||||||
|
default="logit_normal",
|
||||||
|
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode_scale",
|
||||||
|
type=float,
|
||||||
|
default=1.29,
|
||||||
|
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--logit_mean", type=float, default=0.0)
|
|
||||||
parser.add_argument("--logit_std", type=float, default=1.0)
|
|
||||||
parser.add_argument("--mode_scale", type=float, default=1.29)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--optimizer",
|
"--optimizer",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -493,7 +512,7 @@ def parse_args(input_args=None):
|
|||||||
"--prodigy_beta3",
|
"--prodigy_beta3",
|
||||||
type=float,
|
type=float,
|
||||||
default=None,
|
default=None,
|
||||||
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
|
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
|
||||||
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
||||||
)
|
)
|
||||||
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
||||||
@@ -764,7 +783,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
else: # costum prompts were provided, but length does not match size of image dataset
|
else: # custom prompts were provided, but length does not match size of image dataset
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
if self.class_data_root:
|
if self.class_data_root:
|
||||||
@@ -828,6 +847,7 @@ def tokenize_prompt(tokenizer, prompt):
|
|||||||
def _encode_prompt_with_t5(
|
def _encode_prompt_with_t5(
|
||||||
text_encoder,
|
text_encoder,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
max_sequence_length,
|
||||||
prompt=None,
|
prompt=None,
|
||||||
num_images_per_prompt=1,
|
num_images_per_prompt=1,
|
||||||
device=None,
|
device=None,
|
||||||
@@ -838,7 +858,7 @@ def _encode_prompt_with_t5(
|
|||||||
text_inputs = tokenizer(
|
text_inputs = tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
max_length=77,
|
max_length=max_sequence_length,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
@@ -895,6 +915,7 @@ def encode_prompt(
|
|||||||
text_encoders,
|
text_encoders,
|
||||||
tokenizers,
|
tokenizers,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
max_sequence_length,
|
||||||
device=None,
|
device=None,
|
||||||
num_images_per_prompt: int = 1,
|
num_images_per_prompt: int = 1,
|
||||||
):
|
):
|
||||||
@@ -922,6 +943,7 @@ def encode_prompt(
|
|||||||
t5_prompt_embed = _encode_prompt_with_t5(
|
t5_prompt_embed = _encode_prompt_with_t5(
|
||||||
text_encoders[-1],
|
text_encoders[-1],
|
||||||
tokenizers[-1],
|
tokenizers[-1],
|
||||||
|
max_sequence_length,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_images_per_prompt=num_images_per_prompt,
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
device=device if device is not None else text_encoders[-1].device,
|
device=device if device is not None else text_encoders[-1].device,
|
||||||
@@ -1324,7 +1346,9 @@ def main(args):
|
|||||||
|
|
||||||
def compute_text_embeddings(prompt, text_encoders, tokenizers):
|
def compute_text_embeddings(prompt, text_encoders, tokenizers):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
|
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||||
|
text_encoders, tokenizers, prompt, args.max_sequence_length
|
||||||
|
)
|
||||||
prompt_embeds = prompt_embeds.to(accelerator.device)
|
prompt_embeds = prompt_embeds.to(accelerator.device)
|
||||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||||
return prompt_embeds, pooled_prompt_embeds
|
return prompt_embeds, pooled_prompt_embeds
|
||||||
@@ -1347,6 +1371,8 @@ def main(args):
|
|||||||
# Clear the memory here
|
# Clear the memory here
|
||||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||||
del tokenizers, text_encoders
|
del tokenizers, text_encoders
|
||||||
|
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||||
|
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -1362,7 +1388,7 @@ def main(args):
|
|||||||
if args.with_prior_preservation:
|
if args.with_prior_preservation:
|
||||||
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
||||||
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
|
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
|
||||||
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
# if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
||||||
# batch prompts on all training steps
|
# batch prompts on all training steps
|
||||||
else:
|
else:
|
||||||
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
||||||
@@ -1526,7 +1552,15 @@ def main(args):
|
|||||||
bsz = model_input.shape[0]
|
bsz = model_input.shape[0]
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
|
# for weighting schemes where we sample timesteps non-uniformly
|
||||||
|
u = compute_density_for_timestep_sampling(
|
||||||
|
weighting_scheme=args.weighting_scheme,
|
||||||
|
batch_size=bsz,
|
||||||
|
logit_mean=args.logit_mean,
|
||||||
|
logit_std=args.logit_std,
|
||||||
|
mode_scale=args.mode_scale,
|
||||||
|
)
|
||||||
|
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
||||||
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
||||||
|
|
||||||
# Add noise according to flow matching.
|
# Add noise according to flow matching.
|
||||||
@@ -1560,21 +1594,11 @@ def main(args):
|
|||||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||||
# Preconditioning of the model outputs.
|
# Preconditioning of the model outputs.
|
||||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||||
|
# these weighting schemes use a uniform timestep sampling
|
||||||
|
# and instead post-weight the loss
|
||||||
|
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
||||||
|
|
||||||
# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
|
# flow matching loss
|
||||||
if args.weighting_scheme == "sigma_sqrt":
|
|
||||||
weighting = (sigmas**-2.0).float()
|
|
||||||
elif args.weighting_scheme == "logit_normal":
|
|
||||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
|
||||||
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
|
|
||||||
weighting = torch.nn.functional.sigmoid(u)
|
|
||||||
elif args.weighting_scheme == "mode":
|
|
||||||
# See sec 3.1 in the SD3 paper (20).
|
|
||||||
u = torch.rand(size=(bsz,), device=accelerator.device)
|
|
||||||
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
|
||||||
|
|
||||||
# simplified flow matching aka 0-rectified flow matching loss
|
|
||||||
# target = model_input - noise
|
|
||||||
target = model_input
|
target = model_input
|
||||||
|
|
||||||
if args.with_prior_preservation:
|
if args.with_prior_preservation:
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ input_image_path = "/path/to/input_image"
|
|||||||
input_image = Image.open(input_image_path)
|
input_image = Image.open(input_image_path)
|
||||||
edited_images = pipe_lora(num_images_per_prompt=1, prompt=args.edit_prompt, image=input_image, num_inference_steps=1000).images
|
edited_images = pipe_lora(num_images_per_prompt=1, prompt=args.edit_prompt, image=input_image, num_inference_steps=1000).images
|
||||||
edited_images[0].show()
|
edited_images[0].show()
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Results
|
## Results
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:aut
|
|||||||
numactl --membind <node N> -C <cpu list> python python inference_bf16.py
|
numactl --membind <node N> -C <cpu list> python python inference_bf16.py
|
||||||
# Launch with DPMSolverMultistepScheduler
|
# Launch with DPMSolverMultistepScheduler
|
||||||
numactl --membind <node N> -C <cpu list> python python inference_bf16.py --dpm
|
numactl --membind <node N> -C <cpu list> python python inference_bf16.py --dpm
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Accelerating the inference for Stable Diffusion using INT8
|
## Accelerating the inference for Stable Diffusion using INT8
|
||||||
|
|||||||
+5
-5
@@ -561,7 +561,7 @@ def parse_args(input_args=None):
|
|||||||
"--prodigy_beta3",
|
"--prodigy_beta3",
|
||||||
type=float,
|
type=float,
|
||||||
default=None,
|
default=None,
|
||||||
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
|
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
|
||||||
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
||||||
)
|
)
|
||||||
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
||||||
@@ -880,7 +880,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
else: # costum prompts were provided, but length does not match size of image dataset
|
else: # custom prompts were provided, but length does not match size of image dataset
|
||||||
example["instance_prompt"] = self.instance_prompt
|
example["instance_prompt"] = self.instance_prompt
|
||||||
|
|
||||||
if self.class_data_root:
|
if self.class_data_root:
|
||||||
@@ -1363,8 +1363,8 @@ def main(args):
|
|||||||
models = [unet_]
|
models = [unet_]
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
models.extend([text_encoder_one_, text_encoder_two_])
|
models.extend([text_encoder_one_, text_encoder_two_])
|
||||||
# only upcast trainable parameters (LoRA) into fp32
|
# only upcast trainable parameters (LoRA) into fp32
|
||||||
cast_training_params(models)
|
cast_training_params(models)
|
||||||
|
|
||||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||||
@@ -1561,7 +1561,7 @@ def main(args):
|
|||||||
if args.with_prior_preservation:
|
if args.with_prior_preservation:
|
||||||
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
||||||
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
|
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
|
||||||
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
# if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
||||||
# batch prompts on all training steps
|
# batch prompts on all training steps
|
||||||
else:
|
else:
|
||||||
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -170,11 +170,19 @@ For our small Narutos dataset, the effects of Min-SNR weighting strategy might n
|
|||||||
|
|
||||||
Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.
|
Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.
|
||||||
|
|
||||||
|
|
||||||
|
#### Training with EMA weights
|
||||||
|
|
||||||
|
Through the `EMAModel` class, we support a convenient method of tracking an exponential moving average of model parameters. This helps to smooth out noise in model parameter updates and generally improves model performance. If enabled with the `--use_ema` argument, the final model checkpoint that is saved at the end of training will use the EMA weights.
|
||||||
|
|
||||||
|
EMA weights require an additional full-precision copy of the model parameters to be stored in memory, but otherwise have very little performance overhead. `--foreach_ema` can be used to further reduce the overhead. If you are short on VRAM and still want to use EMA weights, you can store them in CPU RAM by using the `--offload_ema` argument. This will keep the EMA weights in pinned CPU memory during the training step. Then, once every model parameter update, it will transfer the EMA weights back to the GPU which can then update the parameters on the GPU, before sending them back to the CPU. Both of these transfers are set up as non-blocking, so CUDA devices should be able to overlap this transfer with other computations. With sufficient bandwidth between the host and device and a sufficiently long gap between model parameter updates, storing EMA weights in CPU RAM should have no additional performance overhead, as long as no other calls force synchronization.
|
||||||
|
|
||||||
#### Training with DREAM
|
#### Training with DREAM
|
||||||
|
|
||||||
We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and is the default of 1 from the paper.
|
We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and is the default of 1 from the paper.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Training with LoRA
|
## Training with LoRA
|
||||||
|
|
||||||
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
|
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
|
||||||
|
|||||||
@@ -239,7 +239,6 @@ accelerate launch --config_file $ACCELERATE_CONFIG_FILE train_text_to_image_lor
|
|||||||
--seed=1234 \
|
--seed=1234 \
|
||||||
--output_dir="sd-naruto-model-lora-sdxl" \
|
--output_dir="sd-naruto-model-lora-sdxl" \
|
||||||
--validation_prompt="cute dragon creature"
|
--validation_prompt="cute dragon creature"
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
@@ -387,6 +387,8 @@ def parse_args():
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
||||||
|
parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.")
|
||||||
|
parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--non_ema_revision",
|
"--non_ema_revision",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -624,7 +626,12 @@ def main():
|
|||||||
ema_unet = UNet2DConditionModel.from_pretrained(
|
ema_unet = UNet2DConditionModel.from_pretrained(
|
||||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||||
)
|
)
|
||||||
ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
|
ema_unet = EMAModel(
|
||||||
|
ema_unet.parameters(),
|
||||||
|
model_cls=UNet2DConditionModel,
|
||||||
|
model_config=ema_unet.config,
|
||||||
|
foreach=args.foreach_ema,
|
||||||
|
)
|
||||||
|
|
||||||
if args.enable_xformers_memory_efficient_attention:
|
if args.enable_xformers_memory_efficient_attention:
|
||||||
if is_xformers_available():
|
if is_xformers_available():
|
||||||
@@ -655,9 +662,14 @@ def main():
|
|||||||
|
|
||||||
def load_model_hook(models, input_dir):
|
def load_model_hook(models, input_dir):
|
||||||
if args.use_ema:
|
if args.use_ema:
|
||||||
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
|
load_model = EMAModel.from_pretrained(
|
||||||
|
os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
|
||||||
|
)
|
||||||
ema_unet.load_state_dict(load_model.state_dict())
|
ema_unet.load_state_dict(load_model.state_dict())
|
||||||
ema_unet.to(accelerator.device)
|
if args.offload_ema:
|
||||||
|
ema_unet.pin_memory()
|
||||||
|
else:
|
||||||
|
ema_unet.to(accelerator.device)
|
||||||
del load_model
|
del load_model
|
||||||
|
|
||||||
for _ in range(len(models)):
|
for _ in range(len(models)):
|
||||||
@@ -833,7 +845,10 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.use_ema:
|
if args.use_ema:
|
||||||
ema_unet.to(accelerator.device)
|
if args.offload_ema:
|
||||||
|
ema_unet.pin_memory()
|
||||||
|
else:
|
||||||
|
ema_unet.to(accelerator.device)
|
||||||
|
|
||||||
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||||
@@ -1011,7 +1026,11 @@ def main():
|
|||||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
if args.use_ema:
|
if args.use_ema:
|
||||||
|
if args.offload_ema:
|
||||||
|
ema_unet.to(device="cuda", non_blocking=True)
|
||||||
ema_unet.step(unet.parameters())
|
ema_unet.step(unet.parameters())
|
||||||
|
if args.offload_ema:
|
||||||
|
ema_unet.to(device="cpu", non_blocking=True)
|
||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
accelerator.log({"train_loss": train_loss}, step=global_step)
|
accelerator.log({"train_loss": train_loss}, step=global_step)
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ else:
|
|||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.29.0.dev0")
|
check_min_version("0.30.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ from setuptools import Command, find_packages, setup
|
|||||||
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
|
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
|
||||||
_deps = [
|
_deps = [
|
||||||
"Pillow", # keep the PIL.Image.Resampling deprecation away
|
"Pillow", # keep the PIL.Image.Resampling deprecation away
|
||||||
"accelerate>=0.29.3",
|
"accelerate>=0.31.0",
|
||||||
"compel==0.1.8",
|
"compel==0.1.8",
|
||||||
"datasets",
|
"datasets",
|
||||||
"filelock",
|
"filelock",
|
||||||
@@ -132,7 +132,7 @@ _deps = [
|
|||||||
"tensorboard",
|
"tensorboard",
|
||||||
"torch>=1.4",
|
"torch>=1.4",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"transformers>=4.25.1",
|
"transformers>=4.41.2",
|
||||||
"urllib3<=2.0.0",
|
"urllib3<=2.0.0",
|
||||||
"black",
|
"black",
|
||||||
]
|
]
|
||||||
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="diffusers",
|
name="diffusers",
|
||||||
version="0.29.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
version="0.30.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
__version__ = "0.29.0.dev0"
|
__version__ = "0.30.0.dev0"
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@@ -91,6 +91,8 @@ else:
|
|||||||
"MultiAdapter",
|
"MultiAdapter",
|
||||||
"PixArtTransformer2DModel",
|
"PixArtTransformer2DModel",
|
||||||
"PriorTransformer",
|
"PriorTransformer",
|
||||||
|
"SD3ControlNetModel",
|
||||||
|
"SD3MultiControlNetModel",
|
||||||
"SD3Transformer2DModel",
|
"SD3Transformer2DModel",
|
||||||
"StableCascadeUNet",
|
"StableCascadeUNet",
|
||||||
"T2IAdapter",
|
"T2IAdapter",
|
||||||
@@ -278,6 +280,7 @@ else:
|
|||||||
"StableCascadeCombinedPipeline",
|
"StableCascadeCombinedPipeline",
|
||||||
"StableCascadeDecoderPipeline",
|
"StableCascadeDecoderPipeline",
|
||||||
"StableCascadePriorPipeline",
|
"StableCascadePriorPipeline",
|
||||||
|
"StableDiffusion3ControlNetPipeline",
|
||||||
"StableDiffusion3Img2ImgPipeline",
|
"StableDiffusion3Img2ImgPipeline",
|
||||||
"StableDiffusion3Pipeline",
|
"StableDiffusion3Pipeline",
|
||||||
"StableDiffusionAdapterPipeline",
|
"StableDiffusionAdapterPipeline",
|
||||||
@@ -501,6 +504,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
MultiAdapter,
|
MultiAdapter,
|
||||||
PixArtTransformer2DModel,
|
PixArtTransformer2DModel,
|
||||||
PriorTransformer,
|
PriorTransformer,
|
||||||
|
SD3ControlNetModel,
|
||||||
|
SD3MultiControlNetModel,
|
||||||
SD3Transformer2DModel,
|
SD3Transformer2DModel,
|
||||||
T2IAdapter,
|
T2IAdapter,
|
||||||
T5FilmDecoder,
|
T5FilmDecoder,
|
||||||
@@ -666,6 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableCascadeCombinedPipeline,
|
StableCascadeCombinedPipeline,
|
||||||
StableCascadeDecoderPipeline,
|
StableCascadeDecoderPipeline,
|
||||||
StableCascadePriorPipeline,
|
StableCascadePriorPipeline,
|
||||||
|
StableDiffusion3ControlNetPipeline,
|
||||||
StableDiffusion3Img2ImgPipeline,
|
StableDiffusion3Img2ImgPipeline,
|
||||||
StableDiffusion3Pipeline,
|
StableDiffusion3Pipeline,
|
||||||
StableDiffusionAdapterPipeline,
|
StableDiffusionAdapterPipeline,
|
||||||
|
|||||||
@@ -716,7 +716,7 @@ class LegacyConfigMixin(ConfigMixin):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
||||||
# To prevent depedency import problem.
|
# To prevent dependency import problem.
|
||||||
from .models.model_loading_utils import _fetch_remapped_cls_from_config
|
from .models.model_loading_utils import _fetch_remapped_cls_from_config
|
||||||
|
|
||||||
# resolve remapping
|
# resolve remapping
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
# 2. run `make deps_table_update`
|
# 2. run `make deps_table_update`
|
||||||
deps = {
|
deps = {
|
||||||
"Pillow": "Pillow",
|
"Pillow": "Pillow",
|
||||||
"accelerate": "accelerate>=0.29.3",
|
"accelerate": "accelerate>=0.31.0",
|
||||||
"compel": "compel==0.1.8",
|
"compel": "compel==0.1.8",
|
||||||
"datasets": "datasets",
|
"datasets": "datasets",
|
||||||
"filelock": "filelock",
|
"filelock": "filelock",
|
||||||
@@ -40,7 +40,7 @@ deps = {
|
|||||||
"tensorboard": "tensorboard",
|
"tensorboard": "tensorboard",
|
||||||
"torch": "torch>=1.4",
|
"torch": "torch>=1.4",
|
||||||
"torchvision": "torchvision",
|
"torchvision": "torchvision",
|
||||||
"transformers": "transformers>=4.25.1",
|
"transformers": "transformers>=4.41.2",
|
||||||
"urllib3": "urllib3<=2.0.0",
|
"urllib3": "urllib3<=2.0.0",
|
||||||
"black": "black",
|
"black": "black",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -569,7 +569,7 @@ class VaeImageProcessor(ConfigMixin):
|
|||||||
|
|
||||||
channel = image.shape[1]
|
channel = image.shape[1]
|
||||||
# don't need any preprocess if the image is latents
|
# don't need any preprocess if the image is latents
|
||||||
if channel == 4:
|
if channel == self.vae_latent_channels:
|
||||||
return image
|
return image
|
||||||
|
|
||||||
height, width = self.get_default_height_width(image, height, width)
|
height, width = self.get_default_height_width(image, height, width)
|
||||||
@@ -585,7 +585,6 @@ class VaeImageProcessor(ConfigMixin):
|
|||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
do_normalize = False
|
do_normalize = False
|
||||||
|
|
||||||
if do_normalize:
|
if do_normalize:
|
||||||
image = self.normalize(image)
|
image = self.normalize(image)
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from ..utils import (
|
|||||||
_get_model_file,
|
_get_model_file,
|
||||||
convert_state_dict_to_diffusers,
|
convert_state_dict_to_diffusers,
|
||||||
convert_state_dict_to_peft,
|
convert_state_dict_to_peft,
|
||||||
|
convert_unet_state_dict_to_peft,
|
||||||
delete_adapter_layers,
|
delete_adapter_layers,
|
||||||
get_adapter_name,
|
get_adapter_name,
|
||||||
get_peft_kwargs,
|
get_peft_kwargs,
|
||||||
@@ -42,7 +43,7 @@ from ..utils import (
|
|||||||
set_adapter_layers,
|
set_adapter_layers,
|
||||||
set_weights_and_activate_adapters,
|
set_weights_and_activate_adapters,
|
||||||
)
|
)
|
||||||
from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
|
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
|
||||||
|
|
||||||
|
|
||||||
if is_transformers_available():
|
if is_transformers_available():
|
||||||
@@ -287,7 +288,7 @@ class LoraLoaderMixin:
|
|||||||
if unet_config is not None:
|
if unet_config is not None:
|
||||||
# use unet config to remap block numbers
|
# use unet config to remap block numbers
|
||||||
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
|
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
|
||||||
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
|
state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
|
||||||
|
|
||||||
return state_dict, network_alphas
|
return state_dict, network_alphas
|
||||||
|
|
||||||
@@ -462,17 +463,18 @@ class LoraLoaderMixin:
|
|||||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||||
|
|
||||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
rank_key = f"{name}.{module}.lora_B.weight"
|
||||||
|
if rank_key not in text_encoder_lora_state_dict:
|
||||||
|
continue
|
||||||
|
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||||
|
|
||||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||||
if patch_mlp:
|
for module in ("fc1", "fc2"):
|
||||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
rank_key = f"{name}.{module}.lora_B.weight"
|
||||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
if rank_key not in text_encoder_lora_state_dict:
|
||||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
continue
|
||||||
|
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
|
||||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
|
||||||
|
|
||||||
if network_alphas is not None:
|
if network_alphas is not None:
|
||||||
alpha_keys = [
|
alpha_keys = [
|
||||||
@@ -1542,6 +1544,11 @@ class SD3LoraLoaderMixin:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(state_dict.keys()) > 0:
|
if len(state_dict.keys()) > 0:
|
||||||
|
# check with first key if is not in peft format
|
||||||
|
first_key = next(iter(state_dict.keys()))
|
||||||
|
if "lora_A" not in first_key:
|
||||||
|
state_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||||
|
|
||||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||||
@@ -1727,3 +1734,78 @@ class SD3LoraLoaderMixin:
|
|||||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||||
|
|
||||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||||
|
|
||||||
|
def fuse_lora(
    self,
    fuse_transformer: bool = True,
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
):
    r"""
    Fuses the LoRA parameters into the original parameters of the corresponding blocks.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        fuse_transformer (`bool`, defaults to `True`): Whether to fuse the transformer LoRA parameters.
        lora_scale (`float`, defaults to 1.0):
            Controls how much to influence the outputs with the LoRA parameters.
        safe_fusing (`bool`, defaults to `False`):
            Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
        adapter_names (`List[str]`, *optional*):
            Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

    Example:

    ```py
    from diffusers import DiffusionPipeline
    import torch

    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    ).to("cuda")
    pipeline.load_lora_weights(
        "nerijs/pixel-art-medium-128-v0.1",
        weight_name="pixel-art-medium-128-v0.1.safetensors",
        adapter_name="pixel",
    )
    pipeline.fuse_lora(lora_scale=0.7)
    ```
    """
    # Nothing to do (and nothing to count) when the transformer is excluded.
    if not fuse_transformer:
        return

    # The counter is bumped before fusing, as in the original statement order.
    self.num_fused_loras += 1

    # Prefer the `transformer` attribute; otherwise fall back to the attribute
    # named by `self.transformer_name`.
    transformer = self.transformer if hasattr(self, "transformer") else getattr(self, self.transformer_name)
    transformer.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||||
|
|
||||||
|
def unfuse_lora(self, unfuse_transformer: bool = True):
    r"""
    Reverses the effect of
    [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
    """
    # `fuse_lora` only increments `num_fused_loras` when it actually fuses the
    # transformer, so a call that unfuses nothing must leave the counter (and
    # everything else) untouched.
    if not unfuse_transformer:
        return

    # Imported lazily so `peft` is only required when there is work to do.
    from peft.tuners.tuners_utils import BaseTunerLayer

    # Prefer the `transformer` attribute; otherwise fall back to the attribute
    # named by `self.transformer_name`.
    transformer = self.transformer if hasattr(self, "transformer") else getattr(self, self.transformer_name)
    for module in transformer.modules():
        if isinstance(module, BaseTunerLayer):
            module.unmerge()

    self.num_fused_loras -= 1
|
||||||
|
|||||||
@@ -123,134 +123,76 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
|
|||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
|
def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
|
||||||
|
"""
|
||||||
|
Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict (`dict`): The state dict to convert.
|
||||||
|
unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
|
||||||
|
text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
|
||||||
|
"text_encoder".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tuple`: A tuple containing the converted state dict and a dictionary of alphas.
|
||||||
|
"""
|
||||||
unet_state_dict = {}
|
unet_state_dict = {}
|
||||||
te_state_dict = {}
|
te_state_dict = {}
|
||||||
te2_state_dict = {}
|
te2_state_dict = {}
|
||||||
network_alphas = {}
|
network_alphas = {}
|
||||||
is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
|
|
||||||
is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
|
|
||||||
is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
|
|
||||||
|
|
||||||
if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
|
# Check for DoRA-enabled LoRAs.
|
||||||
|
if any(
|
||||||
|
"dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
|
||||||
|
for k in state_dict
|
||||||
|
):
|
||||||
if is_peft_version("<", "0.9.0"):
|
if is_peft_version("<", "0.9.0"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||||
)
|
)
|
||||||
|
|
||||||
# every down weight has a corresponding up weight and potentially an alpha weight
|
# Iterate over all LoRA weights.
|
||||||
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
|
all_lora_keys = list(state_dict.keys())
|
||||||
for key in lora_keys:
|
for key in all_lora_keys:
|
||||||
|
if not key.endswith("lora_down.weight"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract LoRA name.
|
||||||
lora_name = key.split(".")[0]
|
lora_name = key.split(".")[0]
|
||||||
|
|
||||||
|
# Find corresponding up weight and alpha.
|
||||||
lora_name_up = lora_name + ".lora_up.weight"
|
lora_name_up = lora_name + ".lora_up.weight"
|
||||||
lora_name_alpha = lora_name + ".alpha"
|
lora_name_alpha = lora_name + ".alpha"
|
||||||
|
|
||||||
|
# Handle U-Net LoRAs.
|
||||||
if lora_name.startswith("lora_unet_"):
|
if lora_name.startswith("lora_unet_"):
|
||||||
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
diffusers_name = _convert_unet_lora_key(key)
|
||||||
|
|
||||||
if "input.blocks" in diffusers_name:
|
# Store down and up weights.
|
||||||
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
|
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
||||||
else:
|
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||||
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
|
|
||||||
|
|
||||||
if "middle.block" in diffusers_name:
|
# Store DoRA scale if present.
|
||||||
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
|
if "dora_scale" in state_dict:
|
||||||
else:
|
|
||||||
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
|
|
||||||
if "output.blocks" in diffusers_name:
|
|
||||||
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
|
|
||||||
else:
|
|
||||||
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
|
|
||||||
|
|
||||||
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
|
|
||||||
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
|
||||||
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
|
||||||
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
|
||||||
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
|
|
||||||
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
|
|
||||||
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
|
|
||||||
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
|
|
||||||
|
|
||||||
# SDXL specificity.
|
|
||||||
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
|
|
||||||
pattern = r"\.\d+(?=\D*$)"
|
|
||||||
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
|
|
||||||
if ".in." in diffusers_name:
|
|
||||||
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
|
|
||||||
if ".out." in diffusers_name:
|
|
||||||
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
|
|
||||||
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
|
|
||||||
diffusers_name = diffusers_name.replace("op", "conv")
|
|
||||||
if "skip" in diffusers_name:
|
|
||||||
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
|
|
||||||
|
|
||||||
# LyCORIS specificity.
|
|
||||||
if "time.emb.proj" in diffusers_name:
|
|
||||||
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
|
|
||||||
if "conv.shortcut" in diffusers_name:
|
|
||||||
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
|
|
||||||
|
|
||||||
# General coverage.
|
|
||||||
if "transformer_blocks" in diffusers_name:
|
|
||||||
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
|
||||||
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
|
||||||
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
|
|
||||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
|
||||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
|
||||||
elif "ff" in diffusers_name:
|
|
||||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
|
||||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
|
||||||
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
|
|
||||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
|
||||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
|
||||||
else:
|
|
||||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
|
||||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
|
||||||
|
|
||||||
if is_unet_dora_lora:
|
|
||||||
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
||||||
unet_state_dict[
|
unet_state_dict[
|
||||||
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
||||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||||
|
|
||||||
|
# Handle text encoder LoRAs.
|
||||||
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
||||||
|
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
|
||||||
|
|
||||||
|
# Store down and up weights for te or te2.
|
||||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||||
key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
|
te_state_dict[diffusers_name] = state_dict.pop(key)
|
||||||
|
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||||
else:
|
else:
|
||||||
key_to_replace = "lora_te2_"
|
|
||||||
|
|
||||||
diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
|
|
||||||
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
|
||||||
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
|
||||||
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
|
||||||
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
|
||||||
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
|
||||||
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
|
||||||
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
|
|
||||||
|
|
||||||
if "self_attn" in diffusers_name:
|
|
||||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
|
||||||
te_state_dict[diffusers_name] = state_dict.pop(key)
|
|
||||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
|
||||||
else:
|
|
||||||
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
|
||||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
|
||||||
elif "mlp" in diffusers_name:
|
|
||||||
# Be aware that this is the new diffusers convention and the rest of the code might
|
|
||||||
# not utilize it yet.
|
|
||||||
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
|
||||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
|
||||||
te_state_dict[diffusers_name] = state_dict.pop(key)
|
|
||||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
|
||||||
else:
|
|
||||||
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
|
||||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
|
||||||
# OneTrainer specificity
|
|
||||||
elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
|
|
||||||
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
||||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||||
|
|
||||||
if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
# Store DoRA scale if present.
|
||||||
|
if "dora_scale" in state_dict:
|
||||||
dora_scale_key_to_replace_te = (
|
dora_scale_key_to_replace_te = (
|
||||||
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
||||||
)
|
)
|
||||||
@@ -263,22 +205,18 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
|
|||||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||||
|
|
||||||
# Rename the alphas so that they can be mapped appropriately.
|
# Store alpha if present.
|
||||||
if lora_name_alpha in state_dict:
|
if lora_name_alpha in state_dict:
|
||||||
alpha = state_dict.pop(lora_name_alpha).item()
|
alpha = state_dict.pop(lora_name_alpha).item()
|
||||||
if lora_name_alpha.startswith("lora_unet_"):
|
network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
|
||||||
prefix = "unet."
|
|
||||||
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
|
|
||||||
prefix = "text_encoder."
|
|
||||||
else:
|
|
||||||
prefix = "text_encoder_2."
|
|
||||||
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
|
||||||
network_alphas.update({new_name: alpha})
|
|
||||||
|
|
||||||
|
# Check if any keys remain.
|
||||||
if len(state_dict) > 0:
|
if len(state_dict) > 0:
|
||||||
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
|
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
|
||||||
|
|
||||||
logger.info("Kohya-style checkpoint detected.")
|
logger.info("Kohya-style checkpoint detected.")
|
||||||
|
|
||||||
|
# Construct final state dict.
|
||||||
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
||||||
te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
|
te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
|
||||||
te2_state_dict = (
|
te2_state_dict = (
|
||||||
@@ -291,3 +229,100 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
|
|||||||
|
|
||||||
new_state_dict = {**unet_state_dict, **te_state_dict}
|
new_state_dict = {**unet_state_dict, **te_state_dict}
|
||||||
return new_state_dict, network_alphas
|
return new_state_dict, network_alphas
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_unet_lora_key(key):
|
||||||
|
"""
|
||||||
|
Converts a U-Net LoRA key to a Diffusers compatible key.
|
||||||
|
"""
|
||||||
|
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
||||||
|
|
||||||
|
# Replace common U-Net naming patterns.
|
||||||
|
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
|
||||||
|
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
|
||||||
|
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
|
||||||
|
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
|
||||||
|
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
|
||||||
|
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
|
||||||
|
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
|
||||||
|
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
||||||
|
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
||||||
|
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
||||||
|
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
|
||||||
|
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
|
||||||
|
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
|
||||||
|
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
|
||||||
|
|
||||||
|
# SDXL specific conversions.
|
||||||
|
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
|
||||||
|
pattern = r"\.\d+(?=\D*$)"
|
||||||
|
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
|
||||||
|
if ".in." in diffusers_name:
|
||||||
|
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
|
||||||
|
if ".out." in diffusers_name:
|
||||||
|
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
|
||||||
|
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
|
||||||
|
diffusers_name = diffusers_name.replace("op", "conv")
|
||||||
|
if "skip" in diffusers_name:
|
||||||
|
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
|
||||||
|
|
||||||
|
# LyCORIS specific conversions.
|
||||||
|
if "time.emb.proj" in diffusers_name:
|
||||||
|
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
|
||||||
|
if "conv.shortcut" in diffusers_name:
|
||||||
|
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
|
||||||
|
|
||||||
|
# General conversions.
|
||||||
|
if "transformer_blocks" in diffusers_name:
|
||||||
|
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
||||||
|
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
||||||
|
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
|
||||||
|
elif "ff" in diffusers_name:
|
||||||
|
pass
|
||||||
|
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return diffusers_name
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_text_encoder_lora_key(key, lora_name):
|
||||||
|
"""
|
||||||
|
Converts a text encoder LoRA key to a Diffusers compatible key.
|
||||||
|
"""
|
||||||
|
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||||
|
key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
|
||||||
|
else:
|
||||||
|
key_to_replace = "lora_te2_"
|
||||||
|
|
||||||
|
diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
|
||||||
|
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
||||||
|
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
||||||
|
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
||||||
|
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
||||||
|
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
||||||
|
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
||||||
|
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
|
||||||
|
|
||||||
|
if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
|
||||||
|
pass
|
||||||
|
elif "mlp" in diffusers_name:
|
||||||
|
# Be aware that this is the new diffusers convention and the rest of the code might
|
||||||
|
# not utilize it yet.
|
||||||
|
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
||||||
|
return diffusers_name
|
||||||
|
|
||||||
|
|
||||||
|
def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
|
||||||
|
"""
|
||||||
|
Gets the correct alpha name for the Diffusers model.
|
||||||
|
"""
|
||||||
|
if lora_name_alpha.startswith("lora_unet_"):
|
||||||
|
prefix = "unet."
|
||||||
|
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
|
||||||
|
prefix = "text_encoder."
|
||||||
|
else:
|
||||||
|
prefix = "text_encoder_2."
|
||||||
|
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
||||||
|
return {new_name: alpha}
|
||||||
|
|||||||
@@ -28,9 +28,11 @@ from .single_file_utils import (
|
|||||||
_legacy_load_safety_checker,
|
_legacy_load_safety_checker,
|
||||||
_legacy_load_scheduler,
|
_legacy_load_scheduler,
|
||||||
create_diffusers_clip_model_from_ldm,
|
create_diffusers_clip_model_from_ldm,
|
||||||
|
create_diffusers_t5_model_from_checkpoint,
|
||||||
fetch_diffusers_config,
|
fetch_diffusers_config,
|
||||||
fetch_original_config,
|
fetch_original_config,
|
||||||
is_clip_model_in_single_file,
|
is_clip_model_in_single_file,
|
||||||
|
is_t5_in_single_file,
|
||||||
load_single_file_checkpoint,
|
load_single_file_checkpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
|
|||||||
is_legacy_loading=is_legacy_loading,
|
is_legacy_loading=is_legacy_loading,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif is_transformers_model and is_t5_in_single_file(checkpoint):
|
||||||
|
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
|
||||||
|
class_obj,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
config=cached_model_config_path,
|
||||||
|
subfolder=name,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
|
||||||
elif is_tokenizer and is_legacy_loading:
|
elif is_tokenizer and is_legacy_loading:
|
||||||
loaded_sub_model = _legacy_load_clip_tokenizer(
|
loaded_sub_model = _legacy_load_clip_tokenizer(
|
||||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||||
|
|||||||
@@ -276,16 +276,18 @@ class FromOriginalModelMixin:
|
|||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||||
if model._keys_to_ignore_on_load_unexpected is not None:
|
|
||||||
for pat in model._keys_to_ignore_on_load_unexpected:
|
|
||||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
|
||||||
logger.warning(
|
|
||||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(diffusers_format_checkpoint)
|
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||||
|
|
||||||
|
if model._keys_to_ignore_on_load_unexpected is not None:
|
||||||
|
for pat in model._keys_to_ignore_on_load_unexpected:
|
||||||
|
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
|
if len(unexpected_keys) > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||||
|
)
|
||||||
|
|
||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
model.to(torch_dtype)
|
model.to(torch_dtype)
|
||||||
|
|||||||
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."
|
|||||||
LDM_CLIP_PREFIX_TO_REMOVE = [
|
LDM_CLIP_PREFIX_TO_REMOVE = [
|
||||||
"cond_stage_model.transformer.",
|
"cond_stage_model.transformer.",
|
||||||
"conditioner.embedders.0.transformer.",
|
"conditioner.embedders.0.transformer.",
|
||||||
"text_encoders.clip_l.transformer.",
|
|
||||||
]
|
]
|
||||||
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
|
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
|
||||||
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
||||||
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
|
|||||||
|
|
||||||
|
|
||||||
def is_open_clip_sd3_model(checkpoint):
|
def is_open_clip_sd3_model(checkpoint):
|
||||||
is_open_clip_sdxl_refiner_model(checkpoint)
|
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_open_clip_sdxl_refiner_model(checkpoint):
|
def is_open_clip_sdxl_refiner_model(checkpoint):
|
||||||
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
|
if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|||||||
return new_checkpoint
|
return new_checkpoint
|
||||||
|
|
||||||
|
|
||||||
def convert_ldm_clip_checkpoint(checkpoint):
|
def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
text_model_dict = {}
|
text_model_dict = {}
|
||||||
|
|
||||||
remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
|
remove_prefixes = []
|
||||||
|
remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
|
||||||
|
if remove_prefix:
|
||||||
|
remove_prefixes.append(remove_prefix)
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
for prefix in remove_prefixes:
|
for prefix in remove_prefixes:
|
||||||
@@ -1263,8 +1268,6 @@ def convert_open_clip_checkpoint(
|
|||||||
else:
|
else:
|
||||||
text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
|
text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
|
||||||
|
|
||||||
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
|
||||||
|
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
|
keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
|
||||||
|
|
||||||
@@ -1313,9 +1316,6 @@ def convert_open_clip_checkpoint(
|
|||||||
else:
|
else:
|
||||||
text_model_dict[diffusers_key] = checkpoint.get(key)
|
text_model_dict[diffusers_key] = checkpoint.get(key)
|
||||||
|
|
||||||
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
|
|
||||||
text_model_dict.pop("text_model.embeddings.position_ids", None)
|
|
||||||
|
|
||||||
return text_model_dict
|
return text_model_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -1376,6 +1376,13 @@ def create_diffusers_clip_model_from_ldm(
|
|||||||
):
|
):
|
||||||
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
|
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
|
||||||
|
|
||||||
|
elif (
|
||||||
|
is_clip_sd3_model(checkpoint)
|
||||||
|
and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
|
||||||
|
):
|
||||||
|
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
|
||||||
|
diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
|
||||||
|
|
||||||
elif is_open_clip_model(checkpoint):
|
elif is_open_clip_model(checkpoint):
|
||||||
prefix = "cond_stage_model.model."
|
prefix = "cond_stage_model.model."
|
||||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
||||||
@@ -1391,26 +1398,28 @@ def create_diffusers_clip_model_from_ldm(
|
|||||||
prefix = "conditioner.embedders.0.model."
|
prefix = "conditioner.embedders.0.model."
|
||||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
||||||
|
|
||||||
elif is_open_clip_sd3_model(checkpoint):
|
elif (
|
||||||
prefix = "text_encoders.clip_g.transformer."
|
is_open_clip_sd3_model(checkpoint)
|
||||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
|
||||||
|
):
|
||||||
|
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
|
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||||
if model._keys_to_ignore_on_load_unexpected is not None:
|
|
||||||
for pat in model._keys_to_ignore_on_load_unexpected:
|
|
||||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
|
||||||
logger.warning(
|
|
||||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(diffusers_format_checkpoint)
|
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||||
|
|
||||||
|
if model._keys_to_ignore_on_load_unexpected is not None:
|
||||||
|
for pat in model._keys_to_ignore_on_load_unexpected:
|
||||||
|
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
|
if len(unexpected_keys) > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||||
|
)
|
||||||
|
|
||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
model.to(torch_dtype)
|
model.to(torch_dtype)
|
||||||
@@ -1755,7 +1764,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
|
|||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
text_model_dict = {}
|
text_model_dict = {}
|
||||||
|
|
||||||
remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
|
remove_prefixes = ["text_encoders.t5xxl.transformer."]
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
for prefix in remove_prefixes:
|
for prefix in remove_prefixes:
|
||||||
@@ -1799,3 +1808,4 @@ def create_diffusers_t5_model_from_checkpoint(
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(diffusers_format_checkpoint)
|
model.load_state_dict(diffusers_format_checkpoint)
|
||||||
|
return model
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ if is_torch_available():
|
|||||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||||
|
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
||||||
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||||
_import_structure["embeddings"] = ["ImageProjection"]
|
_import_structure["embeddings"] = ["ImageProjection"]
|
||||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||||
@@ -74,6 +75,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
VQModel,
|
VQModel,
|
||||||
)
|
)
|
||||||
from .controlnet import ControlNetModel
|
from .controlnet import ControlNetModel
|
||||||
|
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||||
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
|
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
|
||||||
from .embeddings import ImageProjection
|
from .embeddings import ImageProjection
|
||||||
from .modeling_utils import ModelMixin
|
from .modeling_utils import ModelMixin
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
from importlib import import_module
|
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -24,7 +23,6 @@ from ..image_processor import IPAdapterMaskProcessor
|
|||||||
from ..utils import deprecate, logging
|
from ..utils import deprecate, logging
|
||||||
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
|
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
|
||||||
from ..utils.torch_utils import maybe_allow_in_graph
|
from ..utils.torch_utils import maybe_allow_in_graph
|
||||||
from .lora import LoRALinearLayer
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
@@ -259,10 +257,6 @@ class Attention(nn.Module):
|
|||||||
The attention operation to use. Defaults to `None` which uses the default attention operation from
|
The attention operation to use. Defaults to `None` which uses the default attention operation from
|
||||||
`xformers`.
|
`xformers`.
|
||||||
"""
|
"""
|
||||||
is_lora = hasattr(self, "processor") and isinstance(
|
|
||||||
self.processor,
|
|
||||||
LORA_ATTENTION_PROCESSORS,
|
|
||||||
)
|
|
||||||
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
||||||
self.processor,
|
self.processor,
|
||||||
(CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
|
(CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
|
||||||
@@ -274,14 +268,13 @@ class Attention(nn.Module):
|
|||||||
AttnAddedKVProcessor2_0,
|
AttnAddedKVProcessor2_0,
|
||||||
SlicedAttnAddedKVProcessor,
|
SlicedAttnAddedKVProcessor,
|
||||||
XFormersAttnAddedKVProcessor,
|
XFormersAttnAddedKVProcessor,
|
||||||
LoRAAttnAddedKVProcessor,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_memory_efficient_attention_xformers:
|
if use_memory_efficient_attention_xformers:
|
||||||
if is_added_kv_processor and (is_lora or is_custom_diffusion):
|
if is_added_kv_processor and is_custom_diffusion:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
|
f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
|
||||||
)
|
)
|
||||||
if not is_xformers_available():
|
if not is_xformers_available():
|
||||||
raise ModuleNotFoundError(
|
raise ModuleNotFoundError(
|
||||||
@@ -307,18 +300,7 @@ class Attention(nn.Module):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if is_lora:
|
if is_custom_diffusion:
|
||||||
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
|
|
||||||
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
|
|
||||||
processor = LoRAXFormersAttnProcessor(
|
|
||||||
hidden_size=self.processor.hidden_size,
|
|
||||||
cross_attention_dim=self.processor.cross_attention_dim,
|
|
||||||
rank=self.processor.rank,
|
|
||||||
attention_op=attention_op,
|
|
||||||
)
|
|
||||||
processor.load_state_dict(self.processor.state_dict())
|
|
||||||
processor.to(self.processor.to_q_lora.up.weight.device)
|
|
||||||
elif is_custom_diffusion:
|
|
||||||
processor = CustomDiffusionXFormersAttnProcessor(
|
processor = CustomDiffusionXFormersAttnProcessor(
|
||||||
train_kv=self.processor.train_kv,
|
train_kv=self.processor.train_kv,
|
||||||
train_q_out=self.processor.train_q_out,
|
train_q_out=self.processor.train_q_out,
|
||||||
@@ -341,18 +323,7 @@ class Attention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
processor = XFormersAttnProcessor(attention_op=attention_op)
|
processor = XFormersAttnProcessor(attention_op=attention_op)
|
||||||
else:
|
else:
|
||||||
if is_lora:
|
if is_custom_diffusion:
|
||||||
attn_processor_class = (
|
|
||||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
|
||||||
)
|
|
||||||
processor = attn_processor_class(
|
|
||||||
hidden_size=self.processor.hidden_size,
|
|
||||||
cross_attention_dim=self.processor.cross_attention_dim,
|
|
||||||
rank=self.processor.rank,
|
|
||||||
)
|
|
||||||
processor.load_state_dict(self.processor.state_dict())
|
|
||||||
processor.to(self.processor.to_q_lora.up.weight.device)
|
|
||||||
elif is_custom_diffusion:
|
|
||||||
attn_processor_class = (
|
attn_processor_class = (
|
||||||
CustomDiffusionAttnProcessor2_0
|
CustomDiffusionAttnProcessor2_0
|
||||||
if hasattr(F, "scaled_dot_product_attention")
|
if hasattr(F, "scaled_dot_product_attention")
|
||||||
@@ -442,82 +413,6 @@ class Attention(nn.Module):
|
|||||||
if not return_deprecated_lora:
|
if not return_deprecated_lora:
|
||||||
return self.processor
|
return self.processor
|
||||||
|
|
||||||
# TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
|
|
||||||
# serialization format for LoRA Attention Processors. It should be deleted once the integration
|
|
||||||
# with PEFT is completed.
|
|
||||||
is_lora_activated = {
|
|
||||||
name: module.lora_layer is not None
|
|
||||||
for name, module in self.named_modules()
|
|
||||||
if hasattr(module, "lora_layer")
|
|
||||||
}
|
|
||||||
|
|
||||||
# 1. if no layer has a LoRA activated we can return the processor as usual
|
|
||||||
if not any(is_lora_activated.values()):
|
|
||||||
return self.processor
|
|
||||||
|
|
||||||
# If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
|
|
||||||
is_lora_activated.pop("add_k_proj", None)
|
|
||||||
is_lora_activated.pop("add_v_proj", None)
|
|
||||||
# 2. else it is not possible that only some layers have LoRA activated
|
|
||||||
if not all(is_lora_activated.values()):
|
|
||||||
raise ValueError(
|
|
||||||
f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
|
|
||||||
non_lora_processor_cls_name = self.processor.__class__.__name__
|
|
||||||
lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
|
|
||||||
|
|
||||||
hidden_size = self.inner_dim
|
|
||||||
|
|
||||||
# now create a LoRA attention processor from the LoRA layers
|
|
||||||
if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
|
|
||||||
kwargs = {
|
|
||||||
"cross_attention_dim": self.cross_attention_dim,
|
|
||||||
"rank": self.to_q.lora_layer.rank,
|
|
||||||
"network_alpha": self.to_q.lora_layer.network_alpha,
|
|
||||||
"q_rank": self.to_q.lora_layer.rank,
|
|
||||||
"q_hidden_size": self.to_q.lora_layer.out_features,
|
|
||||||
"k_rank": self.to_k.lora_layer.rank,
|
|
||||||
"k_hidden_size": self.to_k.lora_layer.out_features,
|
|
||||||
"v_rank": self.to_v.lora_layer.rank,
|
|
||||||
"v_hidden_size": self.to_v.lora_layer.out_features,
|
|
||||||
"out_rank": self.to_out[0].lora_layer.rank,
|
|
||||||
"out_hidden_size": self.to_out[0].lora_layer.out_features,
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasattr(self.processor, "attention_op"):
|
|
||||||
kwargs["attention_op"] = self.processor.attention_op
|
|
||||||
|
|
||||||
lora_processor = lora_processor_cls(hidden_size, **kwargs)
|
|
||||||
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
|
||||||
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
|
||||||
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
|
||||||
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
|
|
||||||
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
|
|
||||||
lora_processor = lora_processor_cls(
|
|
||||||
hidden_size,
|
|
||||||
cross_attention_dim=self.add_k_proj.weight.shape[0],
|
|
||||||
rank=self.to_q.lora_layer.rank,
|
|
||||||
network_alpha=self.to_q.lora_layer.network_alpha,
|
|
||||||
)
|
|
||||||
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
|
||||||
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
|
||||||
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
|
||||||
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
|
|
||||||
|
|
||||||
# only save if used
|
|
||||||
if self.add_k_proj.lora_layer is not None:
|
|
||||||
lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
|
|
||||||
lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
|
|
||||||
else:
|
|
||||||
lora_processor.add_k_proj_lora = None
|
|
||||||
lora_processor.add_v_proj_lora = None
|
|
||||||
else:
|
|
||||||
raise ValueError(f"{lora_processor_cls} does not exist.")
|
|
||||||
|
|
||||||
return lora_processor
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -1132,9 +1027,7 @@ class JointAttnProcessor2_0:
|
|||||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
hidden_states = hidden_states = F.scaled_dot_product_attention(
|
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||||
query, key, value, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
hidden_states = hidden_states.to(query.dtype)
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
@@ -1406,7 +1299,6 @@ class XFormersAttnProcessor:
|
|||||||
|
|
||||||
|
|
||||||
class AttnProcessorNPU:
|
class AttnProcessorNPU:
|
||||||
|
|
||||||
r"""
|
r"""
|
||||||
Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
|
Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
|
||||||
fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
|
fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
|
||||||
@@ -2241,264 +2133,6 @@ class SpatialNorm(nn.Module):
|
|||||||
return new_f
|
return new_f
|
||||||
|
|
||||||
|
|
||||||
class LoRAAttnProcessor(nn.Module):
    r"""
    Deprecated holder of the four LoRA layers (`to_q_lora`, `to_k_lora`, `to_v_lora`, `to_out_lora`) of an
    attention block.

    Constructing it emits a deprecation notice. Calling it attaches the LoRA layers to the projections of the given
    attention module, replaces that module's processor with a plain `AttnProcessor` and delegates the call to it.

    Args:
        hidden_size (`int`):
            Default feature size used for the LoRA layers.
        cross_attention_dim (`int`, *optional*):
            First size argument of the key/value LoRA layers; when falsy the hidden size is used instead.
        rank (`int`, defaults to 4):
            Default rank of the LoRA layers.
        network_alpha (`int`, *optional*):
            Forwarded unchanged to every `LoRALinearLayer`.
        kwargs (`dict`):
            Optional overrides `q_rank`, `q_hidden_size`, `v_rank`, `v_hidden_size`, `out_rank` and
            `out_hidden_size`. Any other entry is ignored.
    """

    def __init__(
        self,
        hidden_size: int,
        cross_attention_dim: Optional[int] = None,
        rank: int = 4,
        network_alpha: Optional[int] = None,
        **kwargs,
    ):
        deprecate(
            "LoRAAttnProcessor",
            "0.30.0",
            "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`.",
            standard_warn=False,
        )

        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        def _override(name, fallback):
            # An override that is missing or explicitly `None` falls back to the shared default.
            value = kwargs.pop(name, None)
            return fallback if value is None else value

        q_rank = _override("q_rank", rank)
        q_hidden_size = _override("q_hidden_size", hidden_size)
        v_rank = _override("v_rank", rank)
        v_hidden_size = _override("v_hidden_size", hidden_size)
        out_rank = _override("out_rank", rank)
        out_hidden_size = _override("out_hidden_size", hidden_size)

        # Registration order is kept (q, k, v, out). The key layer always uses the shared `rank`/`hidden_size`.
        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Moves the LoRA layers onto `attn`, swaps in `AttnProcessor` and returns its output for `hidden_states`.
        """
        cls_name = type(self).__name__
        message = (
            f"Make sure use {cls_name[4:]} instead by setting"
            "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
            " `LoraLoaderMixin.load_lora_weights`"
        )
        deprecate(cls_name, "0.26.0", message)

        # The LoRA layers follow the device of the incoming activations.
        device = hidden_states.device
        attn.to_q.lora_layer = self.to_q_lora.to(device)
        attn.to_k.lora_layer = self.to_k_lora.to(device)
        attn.to_v.lora_layer = self.to_v_lora.to(device)
        attn.to_out[0].lora_layer = self.to_out_lora.to(device)

        # Unregister this module as the processor before installing the replacement.
        attn._modules.pop("processor")
        attn.processor = AttnProcessor()
        return attn.processor(attn, hidden_states, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAAttnProcessor2_0(nn.Module):
    r"""
    Deprecated holder of the four LoRA layers (`to_q_lora`, `to_k_lora`, `to_v_lora`, `to_out_lora`) of an
    attention block, for setups where `F.scaled_dot_product_attention` exists.

    Constructing it emits a deprecation notice. Calling it attaches the LoRA layers to the projections of the given
    attention module, replaces that module's processor with `AttnProcessor2_0` and delegates the call to it.

    Args:
        hidden_size (`int`):
            Default feature size used for the LoRA layers.
        cross_attention_dim (`int`, *optional*):
            First size argument of the key/value LoRA layers; when falsy the hidden size is used instead.
        rank (`int`, defaults to 4):
            Default rank of the LoRA layers.
        network_alpha (`int`, *optional*):
            Forwarded unchanged to every `LoRALinearLayer`.
        kwargs (`dict`):
            Optional overrides `q_rank`, `q_hidden_size`, `v_rank`, `v_hidden_size`, `out_rank` and
            `out_hidden_size`. Any other entry is ignored.

    Raises:
        ImportError: If `F.scaled_dot_product_attention` is not available.
    """

    def __init__(
        self,
        hidden_size: int,
        cross_attention_dim: Optional[int] = None,
        rank: int = 4,
        network_alpha: Optional[int] = None,
        **kwargs,
    ):
        deprecate(
            "LoRAAttnProcessor2_0",
            "0.30.0",
            "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`.",
            standard_warn=False,
        )

        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        def _override(name, fallback):
            # An override that is missing or explicitly `None` falls back to the shared default.
            value = kwargs.pop(name, None)
            return fallback if value is None else value

        q_rank = _override("q_rank", rank)
        q_hidden_size = _override("q_hidden_size", hidden_size)
        v_rank = _override("v_rank", rank)
        v_hidden_size = _override("v_hidden_size", hidden_size)
        out_rank = _override("out_rank", rank)
        out_hidden_size = _override("out_hidden_size", hidden_size)

        # Registration order is kept (q, k, v, out). The key layer always uses the shared `rank`/`hidden_size`.
        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Moves the LoRA layers onto `attn`, swaps in `AttnProcessor2_0` and returns its output for `hidden_states`.
        """
        cls_name = type(self).__name__
        message = (
            f"Make sure use {cls_name[4:]} instead by setting"
            "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
            " `LoraLoaderMixin.load_lora_weights`"
        )
        deprecate(cls_name, "0.26.0", message)

        # The LoRA layers follow the device of the incoming activations.
        device = hidden_states.device
        attn.to_q.lora_layer = self.to_q_lora.to(device)
        attn.to_k.lora_layer = self.to_k_lora.to(device)
        attn.to_v.lora_layer = self.to_v_lora.to(device)
        attn.to_out[0].lora_layer = self.to_out_lora.to(device)

        # Unregister this module as the processor before installing the replacement.
        attn._modules.pop("processor")
        attn.processor = AttnProcessor2_0()
        return attn.processor(attn, hidden_states, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAXFormersAttnProcessor(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.

    Calling the processor attaches its LoRA layers to the projections of the given attention module, replaces that
    module's processor with an `XFormersAttnProcessor` (configured with the stored `attention_op`) and delegates the
    call to it.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
        kwargs (`dict`):
            Optional overrides `q_rank`, `q_hidden_size`, `v_rank`, `v_hidden_size`, `out_rank` and
            `out_hidden_size`. Any other entry is ignored.
    """

    def __init__(
        self,
        hidden_size: int,
        cross_attention_dim: int,
        rank: int = 4,
        attention_op: Optional[Callable] = None,
        network_alpha: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank
        self.attention_op = attention_op

        def _override(name, fallback):
            # An override that is missing or explicitly `None` falls back to the shared default.
            value = kwargs.pop(name, None)
            return fallback if value is None else value

        q_rank = _override("q_rank", rank)
        q_hidden_size = _override("q_hidden_size", hidden_size)
        v_rank = _override("v_rank", rank)
        v_hidden_size = _override("v_hidden_size", hidden_size)
        out_rank = _override("out_rank", rank)
        out_hidden_size = _override("out_hidden_size", hidden_size)

        # Registration order is kept (q, k, v, out). The key layer always uses the shared `rank`/`hidden_size`.
        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Moves the LoRA layers onto `attn`, swaps in `XFormersAttnProcessor` and returns its output for
        `hidden_states`.
        """
        self_cls_name = self.__class__.__name__
        deprecate(
            self_cls_name,
            "0.26.0",
            (
                f"Make sure use {self_cls_name[4:]} instead by setting"
                "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
                " `LoraLoaderMixin.load_lora_weights`"
            ),
        )

        # The LoRA layers follow the device of the incoming activations.
        device = hidden_states.device
        attn.to_q.lora_layer = self.to_q_lora.to(device)
        attn.to_k.lora_layer = self.to_k_lora.to(device)
        attn.to_v.lora_layer = self.to_v_lora.to(device)
        attn.to_out[0].lora_layer = self.to_out_lora.to(device)

        attn._modules.pop("processor")
        # Forward the configured operator; previously it was stored in `__init__` but dropped here,
        # so a user-selected xFormers op was silently replaced by the default.
        attn.processor = XFormersAttnProcessor(attention_op=self.attention_op)
        return attn.processor(attn, hidden_states, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAAttnAddedKVProcessor(nn.Module):
|
|
||||||
r"""
|
|
||||||
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
|
|
||||||
encoder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_size (`int`, *optional*):
|
|
||||||
The hidden size of the attention layer.
|
|
||||||
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
|
||||||
The number of channels in the `encoder_hidden_states`.
|
|
||||||
rank (`int`, defaults to 4):
|
|
||||||
The dimension of the LoRA update matrices.
|
|
||||||
network_alpha (`int`, *optional*):
|
|
||||||
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
|
||||||
kwargs (`dict`):
|
|
||||||
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
cross_attention_dim: Optional[int] = None,
|
|
||||||
rank: int = 4,
|
|
||||||
network_alpha: Optional[int] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.cross_attention_dim = cross_attention_dim
|
|
||||||
self.rank = rank
|
|
||||||
|
|
||||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
|
||||||
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
|
||||||
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
|
||||||
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
|
||||||
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
|
||||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
|
||||||
|
|
||||||
def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
|
|
||||||
self_cls_name = self.__class__.__name__
|
|
||||||
deprecate(
|
|
||||||
self_cls_name,
|
|
||||||
"0.26.0",
|
|
||||||
(
|
|
||||||
f"Make sure use {self_cls_name[4:]} instead by setting"
|
|
||||||
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
|
|
||||||
" `LoraLoaderMixin.load_lora_weights`"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
|
|
||||||
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
|
|
||||||
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
|
|
||||||
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
|
|
||||||
|
|
||||||
attn._modules.pop("processor")
|
|
||||||
attn.processor = AttnAddedKVProcessor()
|
|
||||||
return attn.processor(attn, hidden_states, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterAttnProcessor(nn.Module):
|
class IPAdapterAttnProcessor(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
Attention processor for Multiple IP-Adapters.
|
Attention processor for Multiple IP-Adapters.
|
||||||
@@ -2927,19 +2561,11 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
LORA_ATTENTION_PROCESSORS = (
|
|
||||||
LoRAAttnProcessor,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
LoRAAttnAddedKVProcessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
ADDED_KV_ATTENTION_PROCESSORS = (
|
ADDED_KV_ATTENTION_PROCESSORS = (
|
||||||
AttnAddedKVProcessor,
|
AttnAddedKVProcessor,
|
||||||
SlicedAttnAddedKVProcessor,
|
SlicedAttnAddedKVProcessor,
|
||||||
AttnAddedKVProcessor2_0,
|
AttnAddedKVProcessor2_0,
|
||||||
XFormersAttnAddedKVProcessor,
|
XFormersAttnAddedKVProcessor,
|
||||||
LoRAAttnAddedKVProcessor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
CROSS_ATTENTION_PROCESSORS = (
|
CROSS_ATTENTION_PROCESSORS = (
|
||||||
@@ -2947,9 +2573,6 @@ CROSS_ATTENTION_PROCESSORS = (
|
|||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
XFormersAttnProcessor,
|
XFormersAttnProcessor,
|
||||||
SlicedAttnProcessor,
|
SlicedAttnProcessor,
|
||||||
LoRAAttnProcessor,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
IPAdapterAttnProcessor,
|
IPAdapterAttnProcessor,
|
||||||
IPAdapterAttnProcessor2_0,
|
IPAdapterAttnProcessor2_0,
|
||||||
)
|
)
|
||||||
@@ -2967,9 +2590,4 @@ AttentionProcessor = Union[
|
|||||||
CustomDiffusionAttnProcessor,
|
CustomDiffusionAttnProcessor,
|
||||||
CustomDiffusionXFormersAttnProcessor,
|
CustomDiffusionXFormersAttnProcessor,
|
||||||
CustomDiffusionAttnProcessor2_0,
|
CustomDiffusionAttnProcessor2_0,
|
||||||
# deprecated
|
|
||||||
LoRAAttnProcessor,
|
|
||||||
LoRAAttnProcessor2_0,
|
|
||||||
LoRAXFormersAttnProcessor,
|
|
||||||
LoRAAttnAddedKVProcessor,
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
|
|
||||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
if hasattr(module, "get_processor"):
|
if hasattr(module, "get_processor"):
|
||||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
processors[f"{name}.processor"] = module.get_processor()
|
||||||
|
|
||||||
for sub_name, child in module.named_children():
|
for sub_name, child in module.named_children():
|
||||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
if hasattr(module, "get_processor"):
|
if hasattr(module, "get_processor"):
|
||||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
processors[f"{name}.processor"] = module.get_processor()
|
||||||
|
|
||||||
for sub_name, child in module.named_children():
|
for sub_name, child in module.named_children():
|
||||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
|||||||
latent_shift: float = 0.5,
|
latent_shift: float = 0.5,
|
||||||
force_upcast: bool = False,
|
force_upcast: bool = False,
|
||||||
scaling_factor: float = 1.0,
|
scaling_factor: float = 1.0,
|
||||||
|
shift_factor: float = 0.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
if hasattr(module, "get_processor"):
|
if hasattr(module, "get_processor"):
|
||||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
processors[f"{name}.processor"] = module.get_processor()
|
||||||
|
|
||||||
for sub_name, child in module.named_children():
|
for sub_name, child in module.named_children():
|
||||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class ControlNetOutput(BaseOutput):
|
|||||||
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||||
used to condition the original UNet's downsampling activations.
|
used to condition the original UNet's downsampling activations.
|
||||||
mid_down_block_re_sample (`torch.Tensor`):
|
mid_down_block_re_sample (`torch.Tensor`):
|
||||||
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
|
||||||
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
||||||
Output can be used to condition the original UNet's middle block activation.
|
Output can be used to condition the original UNet's middle block activation.
|
||||||
"""
|
"""
|
||||||
@@ -530,7 +530,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
|
|
||||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
if hasattr(module, "get_processor"):
|
if hasattr(module, "get_processor"):
|
||||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
processors[f"{name}.processor"] = module.get_processor()
|
||||||
|
|
||||||
for sub_name, child in module.named_children():
|
for sub_name, child in module.named_children():
|
||||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|||||||
@@ -0,0 +1,418 @@
|
|||||||
|
# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||||
|
from ..models.attention import JointTransformerBlock
|
||||||
|
from ..models.attention_processor import Attention, AttentionProcessor
|
||||||
|
from ..models.modeling_utils import ModelMixin
|
||||||
|
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||||
|
from .controlnet import BaseOutput, zero_module
|
||||||
|
from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||||
|
from .transformers.transformer_2d import Transformer2DModelOutput
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SD3ControlNetOutput(BaseOutput):
|
||||||
|
controlnet_block_samples: Tuple[torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_size: int = 128,
|
||||||
|
patch_size: int = 2,
|
||||||
|
in_channels: int = 16,
|
||||||
|
num_layers: int = 18,
|
||||||
|
attention_head_dim: int = 64,
|
||||||
|
num_attention_heads: int = 18,
|
||||||
|
joint_attention_dim: int = 4096,
|
||||||
|
caption_projection_dim: int = 1152,
|
||||||
|
pooled_projection_dim: int = 2048,
|
||||||
|
out_channels: int = 16,
|
||||||
|
pos_embed_max_size: int = 96,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
default_out_channels = in_channels
|
||||||
|
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
||||||
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
|
self.pos_embed = PatchEmbed(
|
||||||
|
height=sample_size,
|
||||||
|
width=sample_size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
in_channels=in_channels,
|
||||||
|
embed_dim=self.inner_dim,
|
||||||
|
pos_embed_max_size=pos_embed_max_size,
|
||||||
|
)
|
||||||
|
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||||
|
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||||
|
)
|
||||||
|
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
|
||||||
|
|
||||||
|
# `attention_head_dim` is doubled to account for the mixing.
|
||||||
|
# It needs to crafted when we get the actual checkpoints.
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
JointTransformerBlock(
|
||||||
|
dim=self.inner_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=self.inner_dim,
|
||||||
|
context_pre_only=False,
|
||||||
|
)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# controlnet_blocks
|
||||||
|
self.controlnet_blocks = nn.ModuleList([])
|
||||||
|
for _ in range(len(self.transformer_blocks)):
|
||||||
|
controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_blocks.append(controlnet_block)
|
||||||
|
pos_embed_input = PatchEmbed(
|
||||||
|
height=sample_size,
|
||||||
|
width=sample_size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
in_channels=in_channels,
|
||||||
|
embed_dim=self.inner_dim,
|
||||||
|
pos_embed_type=None,
|
||||||
|
)
|
||||||
|
self.pos_embed_input = zero_module(pos_embed_input)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||||
|
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||||
|
"""
|
||||||
|
Sets the attention processor to use [feed forward
|
||||||
|
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
chunk_size (`int`, *optional*):
|
||||||
|
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
||||||
|
over each tensor of dim=`dim`.
|
||||||
|
dim (`int`, *optional*, defaults to `0`):
|
||||||
|
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
||||||
|
or dim=1 (sequence length).
|
||||||
|
"""
|
||||||
|
if dim not in [0, 1]:
|
||||||
|
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
||||||
|
|
||||||
|
# By default chunk size is 1
|
||||||
|
chunk_size = chunk_size or 1
|
||||||
|
|
||||||
|
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||||
|
if hasattr(module, "set_chunk_feed_forward"):
|
||||||
|
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||||
|
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||||
|
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||||
|
indexed by its weight name.
|
||||||
|
"""
|
||||||
|
# set recursively
|
||||||
|
processors = {}
|
||||||
|
|
||||||
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
|
if hasattr(module, "get_processor"):
|
||||||
|
processors[f"{name}.processor"] = module.get_processor()
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_add_processors(name, module, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||||
|
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||||
|
r"""
|
||||||
|
Sets the attention processor to use to compute attention.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||||
|
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||||
|
for **all** `Attention` layers.
|
||||||
|
|
||||||
|
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||||
|
processor. This is strongly recommended when setting trainable attention processors.
|
||||||
|
|
||||||
|
"""
|
||||||
|
count = len(self.attn_processors.keys())
|
||||||
|
|
||||||
|
if isinstance(processor, dict) and len(processor) != count:
|
||||||
|
raise ValueError(
|
||||||
|
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||||
|
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
if not isinstance(processor, dict):
|
||||||
|
module.set_processor(processor)
|
||||||
|
else:
|
||||||
|
module.set_processor(processor.pop(f"{name}.processor"))
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_attn_processor(name, module, processor)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||||
|
def fuse_qkv_projections(self):
|
||||||
|
"""
|
||||||
|
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||||
|
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
This API is 🧪 experimental.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
"""
|
||||||
|
self.original_attn_processors = None
|
||||||
|
|
||||||
|
for _, attn_processor in self.attn_processors.items():
|
||||||
|
if "Added" in str(attn_processor.__class__.__name__):
|
||||||
|
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||||
|
|
||||||
|
self.original_attn_processors = self.attn_processors
|
||||||
|
|
||||||
|
for module in self.modules():
|
||||||
|
if isinstance(module, Attention):
|
||||||
|
module.fuse_projections(fuse=True)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||||
|
def unfuse_qkv_projections(self):
|
||||||
|
"""Disables the fused QKV projection if enabled.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
This API is 🧪 experimental.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.original_attn_processors is not None:
|
||||||
|
self.set_attn_processor(self.original_attn_processors)
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
|
if hasattr(module, "gradient_checkpointing"):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
|
||||||
|
config = transformer.config
|
||||||
|
config["num_layers"] = num_layers or config.num_layers
|
||||||
|
controlnet = cls(**config)
|
||||||
|
|
||||||
|
if load_weights_from_transformer:
|
||||||
|
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
|
||||||
|
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
|
||||||
|
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
|
||||||
|
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
|
||||||
|
|
||||||
|
controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
|
||||||
|
|
||||||
|
return controlnet
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
controlnet_cond: torch.Tensor,
|
||||||
|
conditioning_scale: float = 1.0,
|
||||||
|
encoder_hidden_states: torch.FloatTensor = None,
|
||||||
|
pooled_projections: torch.FloatTensor = None,
|
||||||
|
timestep: torch.LongTensor = None,
|
||||||
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||||
|
"""
|
||||||
|
The [`SD3Transformer2DModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||||
|
Input `hidden_states`.
|
||||||
|
controlnet_cond (`torch.Tensor`):
|
||||||
|
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
|
conditioning_scale (`float`, defaults to `1.0`):
|
||||||
|
The scale factor for ControlNet outputs.
|
||||||
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||||
|
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||||
|
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||||
|
from the embeddings of input conditions.
|
||||||
|
timestep ( `torch.LongTensor`):
|
||||||
|
Used to indicate denoising step.
|
||||||
|
joint_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||||
|
tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||||
|
`tuple` where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
if joint_attention_kwargs is not None:
|
||||||
|
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||||
|
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||||
|
else:
|
||||||
|
lora_scale = 1.0
|
||||||
|
|
||||||
|
if USE_PEFT_BACKEND:
|
||||||
|
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||||
|
scale_lora_layers(self, lora_scale)
|
||||||
|
else:
|
||||||
|
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||||
|
logger.warning(
|
||||||
|
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||||
|
)
|
||||||
|
|
||||||
|
height, width = hidden_states.shape[-2:]
|
||||||
|
|
||||||
|
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
||||||
|
temb = self.time_text_embed(timestep, pooled_projections)
|
||||||
|
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||||
|
|
||||||
|
# add
|
||||||
|
hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
|
||||||
|
|
||||||
|
block_res_samples = ()
|
||||||
|
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states,
|
||||||
|
temb,
|
||||||
|
**ckpt_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
encoder_hidden_states, hidden_states = block(
|
||||||
|
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||||
|
)
|
||||||
|
|
||||||
|
block_res_samples = block_res_samples + (hidden_states,)
|
||||||
|
|
||||||
|
controlnet_block_res_samples = ()
|
||||||
|
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
||||||
|
block_res_sample = controlnet_block(block_res_sample)
|
||||||
|
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
||||||
|
|
||||||
|
# 6. scaling
|
||||||
|
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
|
||||||
|
|
||||||
|
if USE_PEFT_BACKEND:
|
||||||
|
# remove `lora_scale` from each PEFT layer
|
||||||
|
unscale_lora_layers(self, lora_scale)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (controlnet_block_res_samples,)
|
||||||
|
|
||||||
|
return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
|
||||||
|
|
||||||
|
|
||||||
|
class SD3MultiControlNetModel(ModelMixin):
|
||||||
|
r"""
|
||||||
|
`SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
|
||||||
|
|
||||||
|
This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
|
||||||
|
compatible with `SD3ControlNetModel`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
controlnets (`List[SD3ControlNetModel]`):
|
||||||
|
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||||
|
`SD3ControlNetModel` as a list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, controlnets):
|
||||||
|
super().__init__()
|
||||||
|
self.nets = nn.ModuleList(controlnets)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
controlnet_cond: List[torch.tensor],
|
||||||
|
conditioning_scale: List[float],
|
||||||
|
pooled_projections: torch.FloatTensor,
|
||||||
|
encoder_hidden_states: torch.FloatTensor = None,
|
||||||
|
timestep: torch.LongTensor = None,
|
||||||
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[SD3ControlNetOutput, Tuple]:
|
||||||
|
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||||
|
block_samples = controlnet(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
pooled_projections=pooled_projections,
|
||||||
|
controlnet_cond=image,
|
||||||
|
conditioning_scale=scale,
|
||||||
|
joint_attention_kwargs=joint_attention_kwargs,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# merge samples
|
||||||
|
if i == 0:
|
||||||
|
control_block_samples = block_samples
|
||||||
|
else:
|
||||||
|
control_block_samples = [
|
||||||
|
control_block_sample + block_sample
|
||||||
|
for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
|
||||||
|
]
|
||||||
|
control_block_samples = (tuple(control_block_samples),)
|
||||||
|
|
||||||
|
return control_block_samples
|
||||||
@@ -114,6 +114,7 @@ def get_down_block_adapter(
|
|||||||
cross_attention_dim: Optional[int] = 1024,
|
cross_attention_dim: Optional[int] = 1024,
|
||||||
add_downsample: bool = True,
|
add_downsample: bool = True,
|
||||||
upcast_attention: Optional[bool] = False,
|
upcast_attention: Optional[bool] = False,
|
||||||
|
use_linear_projection: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
num_layers = 2 # only support sd + sdxl
|
num_layers = 2 # only support sd + sdxl
|
||||||
|
|
||||||
@@ -152,7 +153,7 @@ def get_down_block_adapter(
|
|||||||
in_channels=ctrl_out_channels,
|
in_channels=ctrl_out_channels,
|
||||||
num_layers=transformer_layers_per_block[i],
|
num_layers=transformer_layers_per_block[i],
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
use_linear_projection=True,
|
use_linear_projection=use_linear_projection,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
|
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
|
||||||
)
|
)
|
||||||
@@ -200,6 +201,7 @@ def get_mid_block_adapter(
|
|||||||
num_attention_heads: Optional[int] = 1,
|
num_attention_heads: Optional[int] = 1,
|
||||||
cross_attention_dim: Optional[int] = 1024,
|
cross_attention_dim: Optional[int] = 1024,
|
||||||
upcast_attention: bool = False,
|
upcast_attention: bool = False,
|
||||||
|
use_linear_projection: bool = True,
|
||||||
):
|
):
|
||||||
# Before the midblock application, information is concatted from base to control.
|
# Before the midblock application, information is concatted from base to control.
|
||||||
# Concat doesn't require change in number of channels
|
# Concat doesn't require change in number of channels
|
||||||
@@ -214,7 +216,7 @@ def get_mid_block_adapter(
|
|||||||
resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
|
resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
num_attention_heads=num_attention_heads,
|
num_attention_heads=num_attention_heads,
|
||||||
use_linear_projection=True,
|
use_linear_projection=use_linear_projection,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -308,6 +310,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
|||||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||||
upcast_attention: bool = True,
|
upcast_attention: bool = True,
|
||||||
max_norm_num_groups: int = 32,
|
max_norm_num_groups: int = 32,
|
||||||
|
use_linear_projection: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -381,6 +384,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
|||||||
cross_attention_dim=cross_attention_dim[i],
|
cross_attention_dim=cross_attention_dim[i],
|
||||||
add_downsample=not is_final_block,
|
add_downsample=not is_final_block,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -393,6 +397,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
|||||||
num_attention_heads=num_attention_heads[-1],
|
num_attention_heads=num_attention_heads[-1],
|
||||||
cross_attention_dim=cross_attention_dim[-1],
|
cross_attention_dim=cross_attention_dim[-1],
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
)
|
)
|
||||||
|
|
||||||
# up
|
# up
|
||||||
@@ -489,6 +494,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
|||||||
transformer_layers_per_block=unet.config.transformer_layers_per_block,
|
transformer_layers_per_block=unet.config.transformer_layers_per_block,
|
||||||
upcast_attention=unet.config.upcast_attention,
|
upcast_attention=unet.config.upcast_attention,
|
||||||
max_norm_num_groups=unet.config.norm_num_groups,
|
max_norm_num_groups=unet.config.norm_num_groups,
|
||||||
|
use_linear_projection=unet.config.use_linear_projection,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
|
# ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
|
||||||
@@ -538,6 +544,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
|||||||
addition_embed_type: Optional[str] = None,
|
addition_embed_type: Optional[str] = None,
|
||||||
addition_time_embed_dim: Optional[int] = None,
|
addition_time_embed_dim: Optional[int] = None,
|
||||||
upcast_attention: bool = True,
|
upcast_attention: bool = True,
|
||||||
|
use_linear_projection: bool = True,
|
||||||
time_cond_proj_dim: Optional[int] = None,
|
time_cond_proj_dim: Optional[int] = None,
|
||||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||||
# additional controlnet configs
|
# additional controlnet configs
|
||||||
@@ -595,7 +602,12 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
|||||||
time_embed_dim,
|
time_embed_dim,
|
||||||
cond_proj_dim=time_cond_proj_dim,
|
cond_proj_dim=time_cond_proj_dim,
|
||||||
)
|
)
|
||||||
self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
|
if ctrl_learn_time_embedding:
|
||||||
|
self.ctrl_time_embedding = TimestepEmbedding(
|
||||||
|
in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.ctrl_time_embedding = None
|
||||||
|
|
||||||
if addition_embed_type is None:
|
if addition_embed_type is None:
|
||||||
self.base_add_time_proj = None
|
self.base_add_time_proj = None
|
||||||
@@ -632,6 +644,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
|||||||
cross_attention_dim=cross_attention_dim[i],
|
cross_attention_dim=cross_attention_dim[i],
|
||||||
add_downsample=not is_final_block,
|
add_downsample=not is_final_block,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -647,6 +660,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
|||||||
ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
|
ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
|
||||||
cross_attention_dim=cross_attention_dim[-1],
|
cross_attention_dim=cross_attention_dim[-1],
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
)
|
)
|
||||||
|
|
||||||
# # Create up blocks
|
# # Create up blocks
|
||||||
@@ -690,6 +704,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
|||||||
add_upsample=not is_final_block,
|
add_upsample=not is_final_block,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
norm_num_groups=norm_num_groups,
|
norm_num_groups=norm_num_groups,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -754,6 +769,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
|||||||
"addition_embed_type",
|
"addition_embed_type",
|
||||||
"addition_time_embed_dim",
|
"addition_time_embed_dim",
|
||||||
"upcast_attention",
|
"upcast_attention",
|
||||||
|
"use_linear_projection",
|
||||||
"time_cond_proj_dim",
|
"time_cond_proj_dim",
|
||||||
"projection_class_embeddings_input_dim",
|
"projection_class_embeddings_input_dim",
|
||||||
]
|
]
|
||||||
@@ -864,7 +880,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
if hasattr(module, "get_processor"):
|
if hasattr(module, "get_processor"):
|
||||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
processors[f"{name}.processor"] = module.get_processor()
|
||||||
|
|
||||||
for sub_name, child in module.named_children():
|
for sub_name, child in module.named_children():
|
||||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
@@ -1219,6 +1235,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
|||||||
cross_attention_dim: Optional[int] = 1024,
|
cross_attention_dim: Optional[int] = 1024,
|
||||||
add_downsample: bool = True,
|
add_downsample: bool = True,
|
||||||
upcast_attention: Optional[bool] = False,
|
upcast_attention: Optional[bool] = False,
|
||||||
|
use_linear_projection: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
base_resnets = []
|
base_resnets = []
|
||||||
@@ -1270,7 +1287,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
|||||||
in_channels=base_out_channels,
|
in_channels=base_out_channels,
|
||||||
num_layers=transformer_layers_per_block[i],
|
num_layers=transformer_layers_per_block[i],
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
use_linear_projection=True,
|
use_linear_projection=use_linear_projection,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
norm_num_groups=norm_num_groups,
|
norm_num_groups=norm_num_groups,
|
||||||
)
|
)
|
||||||
@@ -1282,7 +1299,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
|||||||
in_channels=ctrl_out_channels,
|
in_channels=ctrl_out_channels,
|
||||||
num_layers=transformer_layers_per_block[i],
|
num_layers=transformer_layers_per_block[i],
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
use_linear_projection=True,
|
use_linear_projection=use_linear_projection,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
|
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
|
||||||
)
|
)
|
||||||
@@ -1342,6 +1359,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
|||||||
ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
|
ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
|
||||||
cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
|
cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
|
||||||
upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
|
upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
|
||||||
|
use_linear_projection = base_downblock.attentions[0].use_linear_projection
|
||||||
else:
|
else:
|
||||||
has_crossattn = False
|
has_crossattn = False
|
||||||
transformer_layers_per_block = None
|
transformer_layers_per_block = None
|
||||||
@@ -1349,6 +1367,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
|||||||
ctrl_num_attention_heads = None
|
ctrl_num_attention_heads = None
|
||||||
cross_attention_dim = None
|
cross_attention_dim = None
|
||||||
upcast_attention = None
|
upcast_attention = None
|
||||||
|
use_linear_projection = None
|
||||||
add_downsample = base_downblock.downsamplers is not None
|
add_downsample = base_downblock.downsamplers is not None
|
||||||
|
|
||||||
# create model
|
# create model
|
||||||
@@ -1367,6 +1386,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
|||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
add_downsample=add_downsample,
|
add_downsample=add_downsample,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
)
|
)
|
||||||
|
|
||||||
# # load weights
|
# # load weights
|
||||||
@@ -1527,6 +1547,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
|||||||
ctrl_num_attention_heads: Optional[int] = 1,
|
ctrl_num_attention_heads: Optional[int] = 1,
|
||||||
cross_attention_dim: Optional[int] = 1024,
|
cross_attention_dim: Optional[int] = 1024,
|
||||||
upcast_attention: bool = False,
|
upcast_attention: bool = False,
|
||||||
|
use_linear_projection: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -1541,7 +1562,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
|||||||
resnet_groups=norm_num_groups,
|
resnet_groups=norm_num_groups,
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
num_attention_heads=base_num_attention_heads,
|
num_attention_heads=base_num_attention_heads,
|
||||||
use_linear_projection=True,
|
use_linear_projection=use_linear_projection,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1556,7 +1577,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
|||||||
),
|
),
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
num_attention_heads=ctrl_num_attention_heads,
|
num_attention_heads=ctrl_num_attention_heads,
|
||||||
use_linear_projection=True,
|
use_linear_projection=use_linear_projection,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1590,6 +1611,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
|||||||
ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
|
ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
|
||||||
cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
|
cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
|
||||||
upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
|
upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
|
||||||
|
use_linear_projection = base_midblock.attentions[0].use_linear_projection
|
||||||
|
|
||||||
# create model
|
# create model
|
||||||
model = cls(
|
model = cls(
|
||||||
@@ -1603,6 +1625,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
|||||||
ctrl_num_attention_heads=ctrl_num_attention_heads,
|
ctrl_num_attention_heads=ctrl_num_attention_heads,
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
)
|
)
|
||||||
|
|
||||||
# load weights
|
# load weights
|
||||||
@@ -1677,6 +1700,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
|||||||
cross_attention_dim: int = 1024,
|
cross_attention_dim: int = 1024,
|
||||||
add_upsample: bool = True,
|
add_upsample: bool = True,
|
||||||
upcast_attention: bool = False,
|
upcast_attention: bool = False,
|
||||||
|
use_linear_projection: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
resnets = []
|
resnets = []
|
||||||
@@ -1714,7 +1738,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
|||||||
in_channels=out_channels,
|
in_channels=out_channels,
|
||||||
num_layers=transformer_layers_per_block[i],
|
num_layers=transformer_layers_per_block[i],
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
use_linear_projection=True,
|
use_linear_projection=use_linear_projection,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
norm_num_groups=norm_num_groups,
|
norm_num_groups=norm_num_groups,
|
||||||
)
|
)
|
||||||
@@ -1753,12 +1777,14 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
|||||||
num_attention_heads = get_first_cross_attention(base_upblock).heads
|
num_attention_heads = get_first_cross_attention(base_upblock).heads
|
||||||
cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
|
cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
|
||||||
upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
|
upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
|
||||||
|
use_linear_projection = base_upblock.attentions[0].use_linear_projection
|
||||||
else:
|
else:
|
||||||
has_crossattn = False
|
has_crossattn = False
|
||||||
transformer_layers_per_block = None
|
transformer_layers_per_block = None
|
||||||
num_attention_heads = None
|
num_attention_heads = None
|
||||||
cross_attention_dim = None
|
cross_attention_dim = None
|
||||||
upcast_attention = None
|
upcast_attention = None
|
||||||
|
use_linear_projection = None
|
||||||
add_upsample = base_upblock.upsamplers is not None
|
add_upsample = base_upblock.upsamplers is not None
|
||||||
|
|
||||||
# create model
|
# create model
|
||||||
@@ -1776,6 +1802,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
|||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
add_upsample=add_upsample,
|
add_upsample=add_upsample,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
)
|
)
|
||||||
|
|
||||||
# load weights
|
# load weights
|
||||||
|
|||||||
@@ -980,7 +980,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
|
|||||||
|
|
||||||
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
|
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
|
||||||
|
|
||||||
# positionet with text and image infomation
|
# positionet with text and image information
|
||||||
else:
|
else:
|
||||||
phrases_masks = phrases_masks.unsqueeze(-1)
|
phrases_masks = phrases_masks.unsqueeze(-1)
|
||||||
image_masks = image_masks.unsqueeze(-1)
|
image_masks = image_masks.unsqueeze(-1)
|
||||||
@@ -1252,7 +1252,7 @@ class MultiIPAdapterImageProjection(nn.Module):
|
|||||||
if not isinstance(image_embeds, list):
|
if not isinstance(image_embeds, list):
|
||||||
deprecation_message = (
|
deprecation_message = (
|
||||||
"You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
|
"You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
|
||||||
" Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
|
" Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
|
||||||
)
|
)
|
||||||
deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
|
deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
|
||||||
image_embeds = [image_embeds.unsqueeze(1)]
|
image_embeds = [image_embeds.unsqueeze(1)]
|
||||||
|
|||||||
@@ -462,7 +462,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||||
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
||||||
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
||||||
same device.
|
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
|
||||||
|
|
||||||
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
||||||
more information about each option see [designing a device
|
more information about each option see [designing a device
|
||||||
@@ -774,7 +774,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
else: # else let accelerate handle loading and dispatching.
|
else: # else let accelerate handle loading and dispatching.
|
||||||
# Load weights and dispatch according to the device_map
|
# Load weights and dispatch according to the device_map
|
||||||
# by default the device_map is None and the weights are loaded on the CPU
|
# by default the device_map is None and the weights are loaded on the CPU
|
||||||
|
force_hook = True
|
||||||
device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
|
device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
|
||||||
|
if device_map is None and is_sharded:
|
||||||
|
# we load the parameters on the cpu
|
||||||
|
device_map = {"": "cpu"}
|
||||||
|
force_hook = False
|
||||||
try:
|
try:
|
||||||
accelerate.load_checkpoint_and_dispatch(
|
accelerate.load_checkpoint_and_dispatch(
|
||||||
model,
|
model,
|
||||||
@@ -784,7 +789,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
offload_folder=offload_folder,
|
offload_folder=offload_folder,
|
||||||
offload_state_dict=offload_state_dict,
|
offload_state_dict=offload_state_dict,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
force_hooks=True,
|
force_hooks=force_hook,
|
||||||
strict=True,
|
strict=True,
|
||||||
)
|
)
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
@@ -808,12 +813,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
model._temp_convert_self_to_deprecated_attention_blocks()
|
model._temp_convert_self_to_deprecated_attention_blocks()
|
||||||
accelerate.load_checkpoint_and_dispatch(
|
accelerate.load_checkpoint_and_dispatch(
|
||||||
model,
|
model,
|
||||||
model_file,
|
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
||||||
device_map,
|
device_map,
|
||||||
max_memory=max_memory,
|
max_memory=max_memory,
|
||||||
offload_folder=offload_folder,
|
offload_folder=offload_folder,
|
||||||
offload_state_dict=offload_state_dict,
|
offload_state_dict=offload_state_dict,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
|
force_hooks=force_hook,
|
||||||
|
strict=True,
|
||||||
)
|
)
|
||||||
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
||||||
else:
|
else:
|
||||||
@@ -1162,7 +1169,7 @@ class LegacyModelMixin(ModelMixin):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@validate_hf_hub_args
|
@validate_hf_hub_args
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||||
# To prevent depedency import problem.
|
# To prevent dependency import problem.
|
||||||
from .model_loading_utils import _fetch_remapped_cls_from_config
|
from .model_loading_utils import _fetch_remapped_cls_from_config
|
||||||
|
|
||||||
# Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
|
# Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
|
||||||
|
|||||||
@@ -373,7 +373,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
if hasattr(module, "get_processor"):
|
if hasattr(module, "get_processor"):
|
||||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
processors[f"{name}.processor"] = module.get_processor()
|
||||||
|
|
||||||
for sub_name, child in module.named_children():
|
for sub_name, child in module.named_children():
|
||||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user