Compare commits

..

5 Commits

Author SHA1 Message Date
yiyixuxu 7c55ef5000 up 2024-10-25 04:24:51 +02:00
yiyixuxu 0b76fea5dd up 2024-10-25 04:14:23 +02:00
yiyixuxu e9e92d043d up 2024-10-25 02:02:46 +02:00
yiyixuxu 0a6189eb95 add 2024-10-25 01:02:46 +02:00
yiyixuxu 85a9825449 init 2024-10-24 18:12:44 +02:00
179 changed files with 4251 additions and 17924 deletions
-56
View File
@@ -180,62 +180,6 @@ jobs:
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
run_big_gpu_torch_tests:
name: Torch tests on big GPU
strategy:
fail-fast: false
max-parallel: 2
runs-on:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "16gb" --ipc host --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
- name: Selected Torch CUDA Test on big GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-m "big_gpu_with_torch_cuda" \
--make-reports=tests_big_gpu_torch_cuda \
--report-log=tests_big_gpu_torch_cuda.log \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_big_gpu_torch_cuda_stats.txt
cat reports/tests_big_gpu_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: torch_cuda_big_gpu_test_reports
path: reports
- name: Generate Report and Notify Channel
if: always()
run: |
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
run_flax_tpu_tests:
name: Nightly Flax TPU Tests
runs-on: docker-tpu
+3 -3
View File
@@ -81,7 +81,7 @@ jobs:
- name: Environment
run: |
python utils/print_env.py
- name: PyTorch CUDA checkpoint tests on Ubuntu
- name: Slow PyTorch CUDA checkpoint tests on Ubuntu
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
@@ -184,7 +184,7 @@ jobs:
run: |
python utils/print_env.py
- name: Run Flax TPU tests
- name: Run slow Flax TPU tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
@@ -232,7 +232,7 @@ jobs:
run: |
python utils/print_env.py
- name: Run ONNXRuntime CUDA tests
- name: Run slow ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
+1 -2
View File
@@ -4,13 +4,12 @@ on:
workflow_dispatch:
inputs:
runner_type:
description: 'Type of runner to test (aws-g6-4xlarge-plus: a10, aws-g4dn-2xlarge: t4, aws-g6e-xlarge-plus: L40)'
description: 'Type of runner to test (aws-g6-4xlarge-plus: a10 or aws-g4dn-2xlarge: t4)'
type: choice
required: true
options:
- aws-g6-4xlarge-plus
- aws-g4dn-2xlarge
- aws-g6e-xlarge-plus
docker_image:
description: 'Name of the Docker image'
required: true
-14
View File
@@ -188,8 +188,6 @@
title: Metal Performance Shaders (MPS)
- local: optimization/habana
title: Habana Gaudi
- local: optimization/neuron
title: AWS Neuron
title: Optimized hardware
title: Accelerate inference and reduce memory
- sections:
@@ -252,8 +250,6 @@
title: SparseControlNetModel
title: ControlNets
- sections:
- local: api/models/allegro_transformer3d
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d
@@ -270,8 +266,6 @@
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/pixart_transformer2d
title: PixArtTransformer2DModel
- local: api/models/prior_transformer
@@ -304,12 +298,8 @@
- sections:
- local: api/models/autoencoderkl
title: AutoencoderKL
- local: api/models/autoencoderkl_allegro
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_mochi
title: AutoencoderKLMochi
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/consistency_decoder_vae
@@ -326,8 +316,6 @@
sections:
- local: api/pipelines/overview
title: Overview
- local: api/pipelines/allegro
title: Allegro
- local: api/pipelines/amused
title: aMUSEd
- local: api/pipelines/animatediff
@@ -404,8 +392,6 @@
title: Lumina-T2X
- local: api/pipelines/marigold
title: Marigold
- local: api/pipelines/mochi
title: Mochi
- local: api/pipelines/panorama
title: MultiDiffusion
- local: api/pipelines/musicldm
@@ -1,30 +0,0 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AllegroTransformer3DModel
A Diffusion Transformer model for 3D data from [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
The model can be loaded with the following code snippet.
```python
from diffusers import AllegroTransformer3DModel
vae = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## AllegroTransformer3DModel
[[autodoc]] AllegroTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -1,37 +0,0 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLAllegro
The 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLAllegro
vae = AutoencoderKLCogVideoX.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLAllegro
[[autodoc]] AutoencoderKLAllegro
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -1,32 +0,0 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLMochi
The 3D variational autoencoder (VAE) model with KL loss used in [Mochi](https://github.com/genmoai/models) was introduced in [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Tsinghua University & ZhipuAI.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLMochi
vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLMochi
[[autodoc]] AutoencoderKLMochi
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
+2 -2
View File
@@ -39,7 +39,7 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
## ControlNetOutput
[[autodoc]] models.controlnets.controlnet.ControlNetOutput
[[autodoc]] models.controlnet.ControlNetOutput
## FlaxControlNetModel
@@ -47,4 +47,4 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
## FlaxControlNetOutput
[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
+1 -1
View File
@@ -38,5 +38,5 @@ pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-di
## SD3ControlNetOutput
[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput
[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
@@ -1,30 +0,0 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# MochiTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Mochi-1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.
The model can be loaded with the following code snippet.
```python
from diffusers import MochiTransformer3DModel
vae = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```
## MochiTransformer3DModel
[[autodoc]] MochiTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
-34
View File
@@ -1,34 +0,0 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# Allegro
[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, Huan Yang.
The abstract from the paper is:
*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## AllegroPipeline
[[autodoc]] AllegroPipeline
- all
- __call__
## AllegroPipelineOutput
[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput
-36
View File
@@ -1,36 +0,0 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# Mochi
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
*Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. The model is released under a permissive Apache 2.0 license.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## MochiPipeline
[[autodoc]] MochiPipeline
- all
- __call__
## MochiPipelineOutput
[[autodoc]] pipelines.mochi.pipeline_output.MochiPipelineOutput
-61
View File
@@ -1,61 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AWS Neuron
Diffusers functionalities are available on [AWS Inf2 instances](https://aws.amazon.com/ec2/instance-types/inf2/), which are EC2 instances powered by [Neuron machine learning accelerators](https://aws.amazon.com/machine-learning/inferentia/). These instances aim to provide better compute performance (higher throughput, lower latency) with good cost-efficiency, making them good candidates for AWS users to deploy diffusion models to production.
[Optimum Neuron](https://huggingface.co/docs/optimum-neuron/en/index) is the interface between Hugging Face libraries and AWS Accelerators, including AWS [Trainium](https://aws.amazon.com/machine-learning/trainium/) and AWS [Inferentia](https://aws.amazon.com/machine-learning/inferentia/). It supports many of the features in Diffusers with similar APIs, so it is easier to learn if you're already familiar with Diffusers. Once you have created an AWS Inf2 instance, install Optimum Neuron.
```bash
python -m pip install --upgrade-strategy eager optimum[neuronx]
```
<Tip>
We provide pre-built [Hugging Face Neuron Deep Learning AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) (DLAMI) and Optimum Neuron containers for Amazon SageMaker. It's recommended to correctly set up your environment.
</Tip>
The example below demonstrates how to generate images with the Stable Diffusion XL model on an inf2.8xlarge instance (you can switch to cheaper inf2.xlarge instances once the model is compiled). To generate some images, use the [`~optimum.neuron.NeuronStableDiffusionXLPipeline`] class, which is similar to the [`StableDiffusionXLPipeline`] class in Diffusers.
Unlike Diffusers, you need to compile models in the pipeline to the Neuron format, `.neuron`. Launch the following command to export the model to the `.neuron` format.
```bash
optimum-cli export neuron --model stabilityai/stable-diffusion-xl-base-1.0 \
--batch_size 1 \
--height 1024 `# height in pixels of generated image, eg. 768, 1024` \
--width 1024 `# width in pixels of generated image, eg. 768, 1024` \
--num_images_per_prompt 1 `# number of images to generate per prompt, defaults to 1` \
--auto_cast matmul `# cast only matrix multiplication operations` \
--auto_cast_type bf16 `# cast operations from FP32 to BF16` \
sd_neuron_xl/
```
Now generate some images with the pre-compiled SDXL model.
```python
>>> from optimum.neuron import NeuronStableDiffusionXLPipeline
>>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained("sd_neuron_xl/")
>>> prompt = "a pig with wings flying in floating US dollar banknotes in the air, skyscrapers behind, warm color palette, muted colors, detailed, 8k"
>>> image = stable_diffusion_xl(prompt).images[0]
```
<img
src="https://huggingface.co/datasets/Jingya/document_images/resolve/main/optimum/neuron/sdxl_pig.png"
width="256"
height="256"
alt="peggy generated by sdxl on inf2"
/>
Feel free to check out more guides and examples on different use cases from the Optimum Neuron [documentation](https://huggingface.co/docs/optimum-neuron/en/inference_tutorials/stable_diffusion#generate-images-with-stable-diffusion-models-on-aws-inferentia)!
@@ -183,7 +183,7 @@ Add the transformer model to the pipeline for denoising, but set the other model
```py
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
"black-forest-labs/FLUX.1-dev", ,
text_encoder=None,
text_encoder_2=None,
tokenizer=None,
@@ -74,7 +74,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -1650,8 +1650,6 @@ def main(args):
elif isinstance(model, type(unwrap_model(text_encoder_one))):
if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(unwrap_model(text_encoder_two))):
pass # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1778,10 +1776,15 @@ def main(args):
if not args.enable_t5_ti:
# pure textual inversion - only clip
if pure_textual_inversion:
params_to_optimize = [text_parameters_one_with_lr]
params_to_optimize = [
text_parameters_one_with_lr,
]
te_idx = 0
else: # regular te training or regular pivotal for clip
params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
params_to_optimize = [
transformer_parameters_with_lr,
text_parameters_one_with_lr,
]
te_idx = 1
elif args.enable_t5_ti:
# pivotal tuning of clip & t5
@@ -1804,7 +1807,9 @@ def main(args):
]
te_idx = 1
else:
params_to_optimize = [transformer_parameters_with_lr]
params_to_optimize = [
transformer_parameters_with_lr,
]
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
@@ -1864,6 +1869,7 @@ def main(args):
params_to_optimize[-1]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -2192,8 +2198,8 @@ def main(args):
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
model_input.shape[2] // 2,
model_input.shape[3] // 2,
model_input.shape[2],
model_input.shape[3],
accelerator.device,
weight_dtype,
)
@@ -2247,8 +2253,8 @@ def main(args):
)[0]
model_pred = FluxPipeline._unpack_latents(
model_pred,
height=model_input.shape[2] * vae_scale_factor,
width=model_input.shape[3] * vae_scale_factor,
height=int(model_input.shape[2] * vae_scale_factor / 2),
width=int(model_input.shape[3] * vae_scale_factor / 2),
vae_scale_factor=vae_scale_factor,
)
@@ -67,12 +67,11 @@ from diffusers.utils import (
convert_state_dict_to_kohya,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -80,27 +79,30 @@ logger = get_logger(__name__)
def save_model_card(
repo_id: str,
use_dora: bool,
images: list = None,
base_model: str = None,
images=None,
base_model=str,
train_text_encoder=False,
train_text_encoder_ti=False,
token_abstraction_dict=None,
instance_prompt=None,
validation_prompt=None,
instance_prompt=str,
validation_prompt=str,
repo_folder=None,
vae_path=None,
):
img_str = "widget:\n"
lora = "lora" if not use_dora else "dora"
widget_dict = []
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
widget_dict.append(
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
)
else:
widget_dict.append({"text": instance_prompt})
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url:
"image_{i}.png"
"""
if not images:
img_str += f"""
- text: '{instance_prompt}'
"""
embeddings_filename = f"{repo_folder}_emb"
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
@@ -135,7 +137,24 @@ pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_en
trigger_str += f"""
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""
model_description = f"""
yaml = f"""---
tags:
- stable-diffusion
- stable-diffusion-diffusers
- diffusers-training
- text-to-image
- diffusers
- {lora}
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
model_card = f"""
# SD1.5 LoRA DreamBooth - {repo_id}
<Gallery />
@@ -183,28 +202,8 @@ Pivotal tuning was enabled: {train_text_encoder_ti}.
Special VAE used for training: {vae_path}.
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
inference=True,
widget=widget_dict,
)
tags = [
"text-to-image",
"diffusers",
"diffusers-training",
lora,
"template:sd-lora" "stable-diffusion",
"stable-diffusion-diffusers",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def import_model_class_from_model_name_or_path(
@@ -1359,7 +1358,10 @@ def main(args):
else args.adam_weight_decay,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr]
params_to_optimize = [
unet_lora_parameters_with_lr,
text_lora_parameters_one_with_lr,
]
else:
params_to_optimize = [unet_lora_parameters_with_lr]
@@ -1421,6 +1423,7 @@ def main(args):
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -1794,6 +1794,7 @@ def main(args):
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -947,6 +947,7 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
+2 -1
View File
@@ -52,7 +52,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -969,6 +969,7 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
-248
View File
@@ -10,7 +10,6 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Example | Description | Code Example | Colab | Author |
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
| HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
@@ -74,7 +73,6 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
@@ -86,161 +84,6 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
## Example usages
### Adaptive Mask Inpainting
**Hyeonwoo Kim\*, Sookwan Han\*, Patrick Kwon, Hanbyul Joo**
**Seoul National University, Naver Webtoon**
Adaptive Mask Inpainting, presented in the ECCV'24 oral paper [*Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models*](https://snuvclab.github.io/coma), is an algorithm designed to insert humans into scene images without altering the background. Traditional inpainting methods often fail to preserve object geometry and details within the masked region, leading to false affordances. Adaptive Mask Inpainting addresses this issue by progressively specifying the inpainting region over diffusion timesteps, ensuring that the inserted human integrates seamlessly with the existing scene.
Here is the demonstration of Adaptive Mask Inpainting:
<video controls>
<source src="https://snuvclab.github.io/coma/static/videos/adaptive_mask_inpainting_vis.mp4" type="video/mp4">
Your browser does not support the video tag.
</video>
![teaser-img](https://snuvclab.github.io/coma/static/images/example_result_adaptive_mask_inpainting.png)
You can find additional information about Adaptive Mask Inpainting in the [paper](https://arxiv.org/pdf/2401.12978) or in the [project website](https://snuvclab.github.io/coma).
#### Usage example
First, clone the diffusers github repository, and run the following command to set environment.
```Shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers
conda create --name ami python=3.9 -y
conda activate ami
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y
python -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
pip install easydict
pip install diffusers==0.20.2 accelerate safetensors transformers
pip install setuptools==59.5.0
pip install opencv-python
pip install numpy==1.24.1
```
Then, run the below code under 'diffusers' directory.
```python
import numpy as np
import torch
from PIL import Image
from diffusers import DDIMScheduler
from diffusers import DiffusionPipeline
from diffusers.utils import load_image
from examples.community.adaptive_mask_inpainting import download_file, AdaptiveMaskInpaintPipeline, AMI_INSTALL_MESSAGE
print(AMI_INSTALL_MESSAGE)
from easydict import EasyDict
if __name__ == "__main__":
"""
Download Necessary Files
"""
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/model_final_edd263.pkl?download=true",
output_file = "model_final_edd263.pkl",
exist_ok=True,
)
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/pointrend_rcnn_R_50_FPN_3x_coco.yaml?download=true",
output_file = "pointrend_rcnn_R_50_FPN_3x_coco.yaml",
exist_ok=True,
)
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_img.png?download=true",
output_file = "input_img.png",
exist_ok=True,
)
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_mask.png?download=true",
output_file = "input_mask.png",
exist_ok=True,
)
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-PointRend-RCNN-FPN.yaml?download=true",
output_file = "Base-PointRend-RCNN-FPN.yaml",
exist_ok=True,
)
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-RCNN-FPN.yaml?download=true",
output_file = "Base-RCNN-FPN.yaml",
exist_ok=True,
)
"""
Prepare Adaptive Mask Inpainting Pipeline
"""
# device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
num_steps = 50
# Scheduler
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False
)
scheduler.set_timesteps(num_inference_steps=num_steps)
## load models as pipelines
pipeline = AdaptiveMaskInpaintPipeline.from_pretrained(
"Uminosachi/realisticVisionV51_v51VAE-inpainting",
scheduler=scheduler,
torch_dtype=torch.float16,
requires_safety_checker=False
).to(device)
## disable safety checker
enable_safety_checker = False
if not enable_safety_checker:
pipeline.safety_checker = None
"""
Run Adaptive Mask Inpainting
"""
default_mask_image = Image.open("./input_mask.png").convert("L")
init_image = Image.open("./input_img.png").convert("RGB")
seed = 59
generator = torch.Generator(device=device)
generator.manual_seed(seed)
image = pipeline(
prompt="a man sitting on a couch",
negative_prompt="worst quality, normal quality, low quality, bad anatomy, artifacts, blurry, cropped, watermark, greyscale, nsfw",
image=init_image,
default_mask_image=default_mask_image,
guidance_scale=11.0,
strength=0.98,
use_adaptive_mask=True,
generator=generator,
enforce_full_mask_ratio=0.0,
visualization_save_dir="./ECCV2024_adaptive_mask_inpainting_demo", # DON'T CHANGE THIS!!!
human_detection_thres=0.015,
).images[0]
image.save(f'final_img.png')
```
#### [Troubleshooting]
If you run into an error `cannot import name 'cached_download' from 'huggingface_hub'` (issue [1851](https://github.com/easydiffusion/easydiffusion/issues/1851)), remove `cached_download` from the import line in the file `diffusers/utils/dynamic_modules_utils.py`.
For example, change the import line from `.../env/lib/python3.8/site-packages/diffusers/utils/dynamic_modules_utils.py`.
### Flux with CFG
Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
@@ -4602,94 +4445,3 @@ grid_image.save(grid_dir + "sample.png")
`pag_scale` : guidance scale of PAG (ex: 5.0)
`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
# PIXART-α Controlnet pipeline
[Project](https://pixart-alpha.github.io/) / [GitHub](https://github.com/PixArt-alpha/PixArt-alpha/blob/master/asset/docs/pixart_controlnet.md)
This the implementation of the controlnet model and the pipelne for the Pixart-alpha model, adapted to use the HuggingFace Diffusers.
## Example Usage
This example uses the Pixart HED Controlnet model, converted from the control net model as trained by the authors of the paper.
```py
import sys
import os
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
from diffusers.utils import load_image
from diffusers.image_processor import PixArtImageProcessor
from controlnet_aux import HEDdetector
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
weight_dtype = torch.float16
image_size = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
# load controlnet
controlnet = PixArtControlNetAdapterModel.from_pretrained(
controlnet_repo_id,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
pipe = PixArtAlphaControlnetPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
controlnet=controlnet,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
images_path = "images"
control_image_file = "0_7.jpg"
prompt = "battleship in space, galaxy in background"
control_image_name = control_image_file.split('.')[0]
control_image = load_image(f"{images_path}/{control_image_file}")
print(control_image.size)
height, width = control_image.size
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
condition_transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB')),
T.CenterCrop([image_size, image_size]),
])
control_image = condition_transform(control_image)
hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
# run pipeline
with torch.no_grad():
out = pipe(
prompt=prompt,
image=hed_edge,
num_inference_steps=14,
guidance_scale=4.5,
height=image_size,
width=image_size,
)
out.images[0].save(f"{images_path}//{control_image_name}_output.jpg")
```
In the folder examples/pixart there is also a script that can be used to train new models.
Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
File diff suppressed because it is too large Load Diff
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
class MarigoldDepthOutput(BaseOutput):
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = logging.getLogger(__name__)
+3 -4
View File
@@ -65,7 +65,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -152,7 +152,6 @@ def log_validation(
guidance_scale=3.5,
generator=generator,
).images[0]
image = image.resize((args.resolution, args.resolution))
images.append(image)
image_logs.append(
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
@@ -1257,8 +1256,8 @@ def main(args):
latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
batch_size=pixel_latents_tmp.shape[0],
height=pixel_latents_tmp.shape[2] // 2,
width=pixel_latents_tmp.shape[3] // 2,
height=pixel_latents_tmp.shape[2],
width=pixel_latents_tmp.shape[3],
device=pixel_values.device,
dtype=pixel_values.dtype,
)
+1 -1
View File
@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
-15
View File
@@ -170,21 +170,6 @@ accelerate launch train_dreambooth_lora_flux.py \
--push_to_hub
```
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
the exact modules for LoRA training. Here are some examples of target modules you can provide:
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
> [!NOTE]
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
### Text Encoder Training
Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
-34
View File
@@ -147,40 +147,6 @@ accelerate launch train_dreambooth_lora_sd3.py \
--push_to_hub
```
### Targeting Specific Blocks & Layers
As image generation models get bigger & more powerful, more fine-tuners come to find that training only part of the
transformer blocks (sometimes as little as two) can be enough to get great results.
In some cases, it can be even better to maintain some of the blocks/layers frozen.
For **SD3.5-Large** specifically, you may find this information useful (taken from: [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6#12461cdcd19680788a23c650dab26b93):
> [!NOTE]
> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24` )* influence the overall composition/primary form more.
> So, freezing other layers/targeting specific layers is a viable approach.
> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps.
> **Photorealism**
> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing detail degradation introduced by small dataset from happening.
> **Anatomy preservation**
> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks.
We've added `--lora_layers` and `--lora_blocks` to make LoRA training modules configurable.
- with `--lora_blocks` you can specify the block numbers for training. E.g. passing -
```diff
--lora_blocks "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37"
```
will trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained.
- with `--lora_layers` you can specify the types of layers you wish to train.
By default, the trained layers are -
`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v`
If you wish to have a leaner LoRA / train more blocks over layers you could pass -
```diff
+ --lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0
```
This will reduce LoRA size by roughly 50% for the same rank compared to the default.
However, if you're after compact LoRAs, it's our impression that maintaining the default setting for `--lora_layers` and
freezing some of the early & blocks is usually better.
### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
@@ -37,7 +37,6 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"
transformer_layer_type = "single_transformer_blocks.0.attn.to_k"
def test_dreambooth_lora_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
@@ -137,43 +136,6 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
starts_with_transformer = all(
key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
@@ -38,9 +38,6 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"
transformer_block_idx = 0
layer_type = "attn.to_k"
def test_dreambooth_lora_sd3(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
@@ -139,74 +136,6 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_block(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_blocks {self.transformer_block_idx}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
# In this test, only params of transformer block 0 should be in the state dict
starts_with_transformer = all(
key.startswith("transformer.transformer_blocks.0") for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layer(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_layers {self.layer_type}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# In this test, only transformer params of attention layers `attn.to_k` should be in the state dict
starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+16 -37
View File
@@ -57,7 +57,6 @@ from diffusers.utils import (
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -65,16 +64,10 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
import torch_npu
torch.npu.config.allow_internal_format = False
torch.npu.set_compile_mode(jit_compile=False)
def save_model_card(
repo_id: str,
@@ -168,7 +161,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
@@ -196,8 +189,6 @@ def log_validation(
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_torch_npu_available():
torch_npu.npu.empty_cache()
return images
@@ -1044,9 +1035,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
has_supported_fp16_accelerator = (
torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()
)
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
@@ -1084,8 +1073,6 @@ def main(args):
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_torch_npu_available():
torch_npu.npu.empty_cache()
# Handle the repository creation
if accelerator.is_main_process:
@@ -1239,7 +1226,10 @@ def main(args):
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
params_to_optimize = [
transformer_parameters_with_lr,
text_parameters_one_with_lr,
]
else:
params_to_optimize = [transformer_parameters_with_lr]
@@ -1298,9 +1288,11 @@ def main(args):
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -1367,8 +1359,6 @@ def main(args):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_torch_npu_available():
torch_npu.npu.empty_cache()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1550,12 +1540,12 @@ def main(args):
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
model_input = model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
model_input.shape[2] // 2,
model_input.shape[3] // 2,
model_input.shape[2],
model_input.shape[3],
accelerator.device,
weight_dtype,
)
@@ -1590,7 +1580,7 @@ def main(args):
)
# handle guidance
if accelerator.unwrap_model(transformer).config.guidance_embeds:
if transformer.config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -1611,8 +1601,8 @@ def main(args):
# upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
model_pred = FluxPipeline._unpack_latents(
model_pred,
height=model_input.shape[2] * vae_scale_factor,
width=model_input.shape[3] * vae_scale_factor,
height=int(model_input.shape[2] * vae_scale_factor / 2),
width=int(model_input.shape[3] * vae_scale_factor / 2),
vae_scale_factor=vae_scale_factor,
)
@@ -1704,8 +1694,6 @@ def main(args):
# create pipeline
if not args.train_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
else: # even when training the text encoder we're only training text encoder one
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1734,15 +1722,9 @@ def main(args):
)
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_torch_npu_available():
torch_npu.npu.empty_cache()
torch.cuda.empty_cache()
gc.collect()
images = None
del pipeline
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
@@ -1801,9 +1783,6 @@ def main(args):
ignore_patterns=["step_*", "epoch_*"],
)
images = None
del pipeline
accelerator.end_training()
+1 -1
View File
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -177,7 +177,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
@@ -554,15 +554,6 @@ def parse_args(input_args=None):
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--lora_layers",
type=str,
default=None,
help=(
'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
),
)
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -1195,30 +1186,12 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = [
"attn.to_k",
"attn.to_q",
"attn.to_v",
"attn.to_out.0",
"attn.add_k_proj",
"attn.add_q_proj",
"attn.add_v_proj",
"attn.to_add_out",
"ff.net.0.proj",
"ff.net.2",
"ff_context.net.0.proj",
"ff_context.net.2",
]
# now we will add new LoRA weights the transformer layers
# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=target_modules,
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(transformer_lora_config)
if args.train_text_encoder:
@@ -1335,7 +1308,10 @@ def main(args):
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
params_to_optimize = [
transformer_parameters_with_lr,
text_parameters_one_with_lr,
]
else:
params_to_optimize = [transformer_parameters_with_lr]
@@ -1391,12 +1367,14 @@ def main(args):
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of text_encoder_parameters_one to be
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -1667,12 +1645,12 @@ def main(args):
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
model_input.shape[2] // 2,
model_input.shape[3] // 2,
model_input.shape[2],
model_input.shape[3],
accelerator.device,
weight_dtype,
)
@@ -1706,7 +1684,7 @@ def main(args):
)
# handle guidance
if accelerator.unwrap_model(transformer).config.guidance_embeds:
if transformer.config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -1726,8 +1704,8 @@ def main(args):
)[0]
model_pred = FluxPipeline._unpack_latents(
model_pred,
height=model_input.shape[2] * vae_scale_factor,
width=model_input.shape[3] * vae_scale_factor,
height=int(model_input.shape[2] * vae_scale_factor / 2),
width=int(model_input.shape[3] * vae_scale_factor / 2),
vae_scale_factor=vae_scale_factor,
)
@@ -1819,8 +1797,6 @@ def main(args):
# create pipeline
if not args.train_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
@@ -1844,9 +1820,6 @@ def main(args):
del text_encoder_one, text_encoder_two
free_memory()
images = None
del pipeline
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
@@ -1911,9 +1884,6 @@ def main(args):
ignore_patterns=["step_*", "epoch_*"],
)
images = None
del pipeline
accelerator.end_training()
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -86,15 +86,6 @@ def save_model_card(
validation_prompt=None,
repo_folder=None,
):
if "large" in base_model:
model_variant = "SD3.5-Large"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
else:
model_variant = "SD3"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
variant_tags = ["sd3", "sd3-diffusers"]
widget_dict = []
if images is not None:
for i, image in enumerate(images):
@@ -104,7 +95,7 @@ def save_model_card(
)
model_description = f"""
# {model_variant} DreamBooth LoRA - {repo_id}
# SD3 DreamBooth LoRA - {repo_id}
<Gallery />
@@ -129,7 +120,7 @@ You should use `{instance_prompt}` to trigger the image generation.
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained({base_model}, torch_dtype=torch.float16).to('cuda')
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
@@ -144,7 +135,7 @@ For more details, including weighting, merging and fusing LoRAs, check the [docu
## License
Please adhere to the licensing terms as described [here]({license_url}).
Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
@@ -160,11 +151,11 @@ Please adhere to the licensing terms as described [here]({license_url}).
"diffusers-training",
"diffusers",
"lora",
"sd3",
"sd3-diffusers",
"template:sd-lora",
]
tags += variant_tags
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
@@ -571,25 +562,6 @@ def parse_args(input_args=None):
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--lora_layers",
type=str,
default=None,
help=(
"The transformer block layers to apply LoRA training on. Please specify the layers in a comma seperated string."
"For examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md"
),
)
parser.add_argument(
"--lora_blocks",
type=str,
default=None,
help=(
"The transformer blocks to apply LoRA training on. Please specify the block numbers in a comma seperated manner."
'E.g. - "--lora_blocks 12,30" will result in lora training of transformer blocks 12 and 30. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md'
),
)
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -1241,31 +1213,13 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = [
"attn.add_k_proj",
"attn.add_q_proj",
"attn.add_v_proj",
"attn.to_add_out",
"attn.to_k",
"attn.to_out.0",
"attn.to_q",
"attn.to_v",
]
if args.lora_blocks is not None:
target_blocks = [int(block.strip()) for block in args.lora_blocks.split(",")]
target_modules = [
f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
]
# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=target_modules,
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(transformer_lora_config)
@@ -1468,6 +1422,7 @@ def main(args):
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -1402,6 +1402,7 @@ def main(args):
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
+15 -30
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -77,15 +77,6 @@ def save_model_card(
validation_prompt=None,
repo_folder=None,
):
if "large" in base_model:
model_variant = "SD3.5-Large"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
else:
model_variant = "SD3"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
variant_tags = ["sd3", "sd3-diffusers"]
widget_dict = []
if images is not None:
for i, image in enumerate(images):
@@ -95,7 +86,7 @@ def save_model_card(
)
model_description = f"""
# {model_variant} DreamBooth - {repo_id}
# SD3 DreamBooth - {repo_id}
<Gallery />
@@ -122,7 +113,7 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}'
## License
Please adhere to the licensing terms as described `[here]({license_url})`.
Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`.
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
@@ -137,9 +128,10 @@ Please adhere to the licensing terms as described `[here]({license_url})`.
"text-to-image",
"diffusers-training",
"diffusers",
"sd3",
"sd3-diffusers",
"template:sd-lora",
]
tags += variant_tags
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
@@ -902,26 +894,20 @@ def _encode_prompt_with_clip(
tokenizer,
prompt: str,
device=None,
text_input_ids=None,
num_images_per_prompt: int = 1,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
pooled_prompt_embeds = prompt_embeds[0]
@@ -943,7 +929,6 @@ def encode_prompt(
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -952,14 +937,13 @@ def encode_prompt(
clip_prompt_embeds_list = []
clip_pooled_prompt_embeds_list = []
for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoder,
tokenizer=tokenizer,
prompt=prompt,
device=device if device is not None else text_encoder.device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
)
clip_prompt_embeds_list.append(prompt_embeds)
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
@@ -1336,6 +1320,7 @@ def main(args):
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+1 -10
View File
@@ -1,13 +1,4 @@
## Diffusion-based Policy Learning for RL
`diffusion_policy` implements [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/), a diffusion model that predicts robot action sequences in reinforcement learning tasks.
This example implements a robot control model for pushing a T-shaped block into a target area. The model takes in current state observations as input, and outputs a trajectory of subsequent steps to follow.
To execute the script, run `diffusion_policy.py`
## Diffuser Locomotion
# Overview
These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
There are two ways to use the script, `run_diffuser_locomotion.py`.
@@ -1,201 +0,0 @@
import numpy as np
import numpy.core.multiarray as multiarray
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from torch.serialization import add_safe_globals
from diffusers import DDPMScheduler, UNet1DModel
add_safe_globals(
[
multiarray._reconstruct,
np.ndarray,
np.dtype,
np.dtype(np.float32).type,
np.dtype(np.float64).type,
np.dtype(np.int32).type,
np.dtype(np.int64).type,
type(np.dtype(np.float32)),
type(np.dtype(np.float64)),
type(np.dtype(np.int32)),
type(np.dtype(np.int64)),
]
)
"""
An example of using HuggingFace's diffusers library for diffusion policy,
generating smooth movement trajectories.
This implements a robot control model for pushing a T-shaped block into a target area.
The model takes in the robot arm position, block position, and block angle,
then outputs a sequence of 16 (x,y) positions for the robot arm to follow.
"""
class ObservationEncoder(nn.Module):
"""
Converts raw robot observations (positions/angles) into a more compact representation
state_dim (int): Dimension of the input state vector (default: 5)
[robot_x, robot_y, block_x, block_y, block_angle]
- Input shape: (batch_size, state_dim)
- Output shape: (batch_size, 256)
"""
def __init__(self, state_dim):
super().__init__()
self.net = nn.Sequential(nn.Linear(state_dim, 512), nn.ReLU(), nn.Linear(512, 256))
def forward(self, x):
return self.net(x)
class ObservationProjection(nn.Module):
"""
Takes the encoded observation and transforms it into 32 values that represent the current robot/block situation.
These values are used as additional contextual information during the diffusion model's trajectory generation.
- Input: 256-dim vector (padded to 512)
Shape: (batch_size, 256)
- Output: 32 contextual information values for the diffusion model
Shape: (batch_size, 32)
"""
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(32, 512))
self.bias = nn.Parameter(torch.zeros(32))
def forward(self, x): # pad 256-dim input to 512-dim with zeros
if x.size(-1) == 256:
x = torch.cat([x, torch.zeros(*x.shape[:-1], 256, device=x.device)], dim=-1)
return nn.functional.linear(x, self.weight, self.bias)
class DiffusionPolicy:
"""
Implements diffusion policy for generating robot arm trajectories.
Uses diffusion to generate sequences of positions for a robot arm, conditioned on
the current state of the robot and the block it needs to push.
The model expects observations in pixel coordinates (0-512 range) and block angle in radians.
It generates trajectories as sequences of (x,y) coordinates also in the 0-512 range.
"""
def __init__(self, state_dim=5, device="cpu"):
self.device = device
# define valid ranges for inputs/outputs
self.stats = {
"obs": {"min": torch.zeros(5), "max": torch.tensor([512, 512, 512, 512, 2 * np.pi])},
"action": {"min": torch.zeros(2), "max": torch.full((2,), 512)},
}
self.obs_encoder = ObservationEncoder(state_dim).to(device)
self.obs_projection = ObservationProjection().to(device)
# UNet model that performs the denoising process
# takes in concatenated action (2 channels) and context (32 channels) = 34 channels
# outputs predicted action (2 channels for x,y coordinates)
self.model = UNet1DModel(
sample_size=16, # length of trajectory sequence
in_channels=34,
out_channels=2,
layers_per_block=2, # number of layers per each UNet block
block_out_channels=(128,), # number of output neurons per layer in each block
down_block_types=("DownBlock1D",), # reduce the resolution of data
up_block_types=("UpBlock1D",), # increase the resolution of data
).to(device)
# noise scheduler that controls the denoising process
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=100, # number of denoising steps
beta_schedule="squaredcos_cap_v2", # type of noise schedule
)
# load pre-trained weights from HuggingFace
checkpoint = torch.load(
hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), weights_only=True, map_location=device
)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.obs_encoder.load_state_dict(checkpoint["encoder_state_dict"])
self.obs_projection.load_state_dict(checkpoint["projection_state_dict"])
# scales data to [-1, 1] range for neural network processing
def normalize_data(self, data, stats):
return ((data - stats["min"]) / (stats["max"] - stats["min"])) * 2 - 1
# converts normalized data back to original range
def unnormalize_data(self, ndata, stats):
return ((ndata + 1) / 2) * (stats["max"] - stats["min"]) + stats["min"]
@torch.no_grad()
def predict(self, observation):
"""
Generates a trajectory of robot arm positions given the current state.
Args:
observation (torch.Tensor): Current state [robot_x, robot_y, block_x, block_y, block_angle]
Shape: (batch_size, 5)
Returns:
torch.Tensor: Sequence of (x,y) positions for the robot arm to follow
Shape: (batch_size, 16, 2) where:
- 16 is the number of steps in the trajectory
- 2 is the (x,y) coordinates in pixel space (0-512)
The function first encodes the observation, then uses it to condition a diffusion
process that gradually denoises random trajectories into smooth, purposeful movements.
"""
observation = observation.to(self.device)
normalized_obs = self.normalize_data(observation, self.stats["obs"])
# encode the observation into context values for the diffusion model
cond = self.obs_projection(self.obs_encoder(normalized_obs))
# keeps first & second dimension sizes unchanged, and multiplies last dimension by 16
cond = cond.view(normalized_obs.shape[0], -1, 1).expand(-1, -1, 16)
# initialize action with noise - random noise that will be refined into a trajectory
action = torch.randn((observation.shape[0], 2, 16), device=self.device)
# denoise
# at each step `t`, the current noisy trajectory (`action`) & conditioning info (context) are
# fed into the model to predict a denoised trajectory, then uses self.noise_scheduler.step to
# apply this prediction & slightly reduce the noise in `action` more
self.noise_scheduler.set_timesteps(100)
for t in self.noise_scheduler.timesteps:
model_output = self.model(torch.cat([action, cond], dim=1), t)
action = self.noise_scheduler.step(model_output.sample, t, action).prev_sample
action = action.transpose(1, 2) # reshape to [batch, 16, 2]
action = self.unnormalize_data(action, self.stats["action"]) # scale back to coordinates
return action
if __name__ == "__main__":
policy = DiffusionPolicy()
# sample of a single observation
# robot arm starts in center, block is slightly left and up, rotated 90 degrees
obs = torch.tensor(
[
[
256.0, # robot arm x position (middle of screen)
256.0, # robot arm y position (middle of screen)
200.0, # block x position
300.0, # block y position
np.pi / 2, # block angle (90 degrees)
]
]
)
action = policy.predict(obs)
print("Action shape:", action.shape) # should be [1, 16, 2] - one trajectory of 16 x,y positions
print("\nPredicted trajectory:")
for i, (x, y) in enumerate(action[0]):
print(f"Step {i:2d}: x={x:6.1f}, y={y:6.1f}")
@@ -1,167 +0,0 @@
## LoRA fine-tuning Flux.1 Dev with quantization
> [!NOTE]
> This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further.
This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow:
* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
* Even though optional, we load the T5-xxl in NF4 to further reduce the memory foot-print.
* `train_dreambooth_lora_flux_miniature.py` takes care of training:
* Since we already precomputed the text embeddings, we don't load the text encoders.
* We load the VAE and use it to precompute the image latents and we then delete it.
* Load the Flux transformer, quantize it with the [NF4 datatype](https://arxiv.org/abs/2305.14314) through `bitsandbytes`, prepare it for 4bit training.
* Add LoRA adapter layers to it and then ensure they are kept in FP32 precision.
* Train!
To run training in a memory-optimized manner, we additionally use:
* 8Bit Adam
* Gradient checkpointing
We have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow.
## Training
Ensure you have installed the required libraries:
```bash
pip install -U transformers accelerate bitsandbytes peft datasets
pip install git+https://github.com/huggingface/diffusers -U
```
Now, compute the text embeddings:
```bash
python compute_embeddings.py
```
It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model:
```bash
huggingface-cli
```
Then launch:
```bash
accelerate launch --config_file=accelerate.yaml \
train_dreambooth_lora_flux_miniature.py \
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
--data_df_path="embeddings.parquet" \
--output_dir="yarn_art_lora_flux_nf4" \
--mixed_precision="fp16" \
--use_8bit_adam \
--weighting_scheme="none" \
--resolution=1024 \
--train_batch_size=1 \
--repeats=1 \
--learning_rate=1e-4 \
--guidance_scale=1 \
--report_to="wandb" \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--cache_latents \
--rank=4 \
--max_train_steps=700 \
--seed="0"
```
We can direcly pass a quantized checkpoint path, too:
```diff
+ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg"
```
Depending on the machine, training time will vary but for our case, it was 1.5 hours. It maybe possible to speed this up by using `torch.bfloat16`.
We support training with the DeepSpeed Zero2 optimizer, too. To use it, first install DeepSpeed:
```bash
pip install -Uq deepspeed
```
And then launch:
```bash
accelerate launch --config_file=ds2.yaml \
train_dreambooth_lora_flux_miniature.py \
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
--data_df_path="embeddings.parquet" \
--output_dir="yarn_art_lora_flux_nf4" \
--mixed_precision="no" \
--use_8bit_adam \
--weighting_scheme="none" \
--resolution=1024 \
--train_batch_size=1 \
--repeats=1 \
--learning_rate=1e-4 \
--guidance_scale=1 \
--report_to="wandb" \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--cache_latents \
--rank=4 \
--max_train_steps=700 \
--seed="0"
```
## Inference
When loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example:
1. First, load the original model and merge the LoRA params into it:
```py
from diffusers import FluxPipeline
import torch
ckpt_id = "black-forest-labs/FLUX.1-dev"
pipeline = FluxPipeline.from_pretrained(
ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
)
pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline.transformer.save_pretrained("fused_transformer")
```
2. Quantize the model and run inference
```py
from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig
import torch
ckpt_id = "black-forest-labs/FLUX.1-dev"
bnb_4bit_compute_dtype = torch.float16
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
)
transformer = FluxTransformer2DModel.from_pretrained(
"fused_transformer",
quantization_config=nf4_config,
torch_dtype=bnb_4bit_compute_dtype,
)
pipeline = AutoPipelineForText2Image.from_pretrained(
ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype
)
pipeline.enable_model_cpu_offload()
image = pipeline(
"a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768
).images[0]
image.save("yarn_merged.png")
```
| Dequantize, merge, quantize | Merging directly into quantized model |
|-------|-------|
| ![Image A](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/merged.png) | ![Image B](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/unmerged.png) |
As we can notice the first column result follows the style more closely.
@@ -1,17 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: NO
downcast_bf16: 'no'
enable_cpu_affinity: true
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@@ -1,107 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import pandas as pd
import torch
from datasets import load_dataset
from huggingface_hub.utils import insecure_hashlib
from tqdm.auto import tqdm
from transformers import T5EncoderModel
from diffusers import FluxPipeline
MAX_SEQ_LENGTH = 77
OUTPUT_PATH = "embeddings.parquet"
def generate_image_hash(image):
return insecure_hashlib.sha256(image.tobytes()).hexdigest()
def load_flux_dev_pipeline():
id = "black-forest-labs/FLUX.1-dev"
text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto")
pipeline = FluxPipeline.from_pretrained(
id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
)
return pipeline
@torch.no_grad()
def compute_embeddings(pipeline, prompts, max_sequence_length):
all_prompt_embeds = []
all_pooled_prompt_embeds = []
all_text_ids = []
for prompt in tqdm(prompts, desc="Encoding prompts."):
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)
all_prompt_embeds.append(prompt_embeds)
all_pooled_prompt_embeds.append(pooled_prompt_embeds)
all_text_ids.append(text_ids)
max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
print(f"Max memory allocated: {max_memory:.3f} GB")
return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids
def run(args):
dataset = load_dataset("Norod78/Yarn-art-style", split="train")
image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset}
all_prompts = list(image_prompts.values())
print(f"{len(all_prompts)=}")
pipeline = load_flux_dev_pipeline()
all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(
pipeline, all_prompts, args.max_sequence_length
)
data = []
for i, (image_hash, _) in enumerate(image_prompts.items()):
data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i]))
print(f"{len(data)=}")
# Create a DataFrame
embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"]
df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols)
print(f"{len(df)=}")
# Convert embedding lists to arrays (for proper storage in parquet)
for col in embedding_cols:
df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
# Save the dataframe to a parquet file
df.to_parquet(args.output_path)
print(f"Data successfully serialized to {args.output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--max_sequence_length",
type=int,
default=MAX_SEQ_LENGTH,
help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
)
parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
args = parser.parse_args()
run(args)
@@ -1,23 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@@ -1,2 +0,0 @@
images/
output/
@@ -1,307 +0,0 @@
from typing import Any, Dict, Optional
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import PixArtTransformer2DModel
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.torch_utils import is_torch_version
class PixArtControlNetAdapterBlock(nn.Module):
def __init__(
self,
block_index,
# taken from PixArtTransformer2DModel
num_attention_heads: int = 16,
attention_head_dim: int = 72,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = 1152,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: Optional[int] = 1000,
upcast_attention: bool = False,
norm_type: str = "ada_norm_single",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
attention_type: Optional[str] = "default",
):
super().__init__()
self.block_index = block_index
self.inner_dim = num_attention_heads * attention_head_dim
# the first block has a zero before layer
if self.block_index == 0:
self.before_proj = nn.Linear(self.inner_dim, self.inner_dim)
nn.init.zeros_(self.before_proj.weight)
nn.init.zeros_(self.before_proj.bias)
self.transformer_block = BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)
self.after_proj = nn.Linear(self.inner_dim, self.inner_dim)
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)
def train(self, mode: bool = True):
self.transformer_block.train(mode)
if self.block_index == 0:
self.before_proj.train(mode)
self.after_proj.train(mode)
def forward(
self,
hidden_states: torch.Tensor,
controlnet_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
):
if self.block_index == 0:
controlnet_states = self.before_proj(controlnet_states)
controlnet_states = hidden_states + controlnet_states
controlnet_states_down = self.transformer_block(
hidden_states=controlnet_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
class_labels=None,
)
controlnet_states_left = self.after_proj(controlnet_states_down)
return controlnet_states_left, controlnet_states_down
class PixArtControlNetAdapterModel(ModelMixin, ConfigMixin):
# N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer
@register_to_config
def __init__(self, num_layers=13) -> None:
super().__init__()
self.num_layers = num_layers
self.controlnet_blocks = nn.ModuleList(
[PixArtControlNetAdapterBlock(block_index=i) for i in range(num_layers)]
)
@classmethod
def from_transformer(cls, transformer: PixArtTransformer2DModel):
control_net = PixArtControlNetAdapterModel()
# copied the specified number of blocks from the transformer
for depth in range(control_net.num_layers):
control_net.controlnet_blocks[depth].transformer_block.load_state_dict(
transformer.transformer_blocks[depth].state_dict()
)
return control_net
def train(self, mode: bool = True):
for block in self.controlnet_blocks:
block.train(mode)
class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
def __init__(
self,
transformer: PixArtTransformer2DModel,
controlnet: PixArtControlNetAdapterModel,
blocks_num=13,
init_from_transformer=False,
training=False,
):
super().__init__()
self.blocks_num = blocks_num
self.gradient_checkpointing = False
self.register_to_config(**transformer.config)
self.training = training
if init_from_transformer:
# copies the specified number of blocks from the transformer
controlnet.from_transformer(transformer, self.blocks_num)
self.transformer = transformer
self.controlnet = controlnet
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
controlnet_cond: Optional[torch.Tensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
if self.transformer.use_additional_conditions and added_cond_kwargs is None:
raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
batch_size = hidden_states.shape[0]
height, width = (
hidden_states.shape[-2] // self.transformer.config.patch_size,
hidden_states.shape[-1] // self.transformer.config.patch_size,
)
hidden_states = self.transformer.pos_embed(hidden_states)
timestep, embedded_timestep = self.transformer.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
if self.transformer.caption_projection is not None:
encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
controlnet_states_down = None
if controlnet_cond is not None:
controlnet_states_down = self.transformer.pos_embed(controlnet_cond)
# 2. Blocks
for block_index, block in enumerate(self.transformer.transformer_blocks):
if self.training and self.gradient_checkpointing:
# rc todo: for training and gradient checkpointing
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
exit(1)
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
None,
**ckpt_kwargs,
)
else:
# the control nets are only used for the blocks 1 to self.blocks_num
if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None:
controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[
block_index - 1
](
hidden_states=hidden_states, # used only in the first block
controlnet_states=controlnet_states_down,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + controlnet_states_left
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=None,
)
# 3. Output
shift, scale = (
self.transformer.scale_shift_table[None]
+ embedded_timestep[:, None].to(self.transformer.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.transformer.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
hidden_states = self.transformer.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(
-1,
height,
width,
self.transformer.config.patch_size,
self.transformer.config.patch_size,
self.transformer.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(
-1,
self.transformer.out_channels,
height * self.transformer.config.patch_size,
width * self.transformer.config.patch_size,
)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
File diff suppressed because it is too large Load Diff
@@ -1,6 +0,0 @@
transformers
SentencePiece
torchvision
controlnet-aux
datasets
# wandb
@@ -1,75 +0,0 @@
import torch
import torchvision.transforms as T
from controlnet_aux import HEDdetector
from diffusers.utils import load_image
from examples.research_projects.pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
from examples.research_projects.pixart.pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
weight_dtype = torch.float16
image_size = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
# load controlnet
controlnet = PixArtControlNetAdapterModel.from_pretrained(
controlnet_repo_id,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
pipe = PixArtAlphaControlnetPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
controlnet=controlnet,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
images_path = "images"
control_image_file = "0_7.jpg"
# prompt = "cinematic photo of superman in action . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
# prompt = "yellow modern car, city in background, beautiful rainy day"
# prompt = "modern villa, clear sky, suny day . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
# prompt = "robot dog toy in park . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
# prompt = "purple car, on highway, beautiful sunny day"
# prompt = "realistical photo of a loving couple standing in the open kitchen of the living room, cooking ."
prompt = "battleship in space, galaxy in background"
control_image_name = control_image_file.split(".")[0]
control_image = load_image(f"{images_path}/{control_image_file}")
print(control_image.size)
height, width = control_image.size
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
condition_transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB")),
T.CenterCrop([image_size, image_size]),
]
)
control_image = condition_transform(control_image)
hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
# run pipeline
with torch.no_grad():
out = pipe(
prompt=prompt,
image=hed_edge,
num_inference_steps=14,
guidance_scale=4.5,
height=image_size,
width=image_size,
)
out.images[0].save(f"{images_path}//{control_image_name}_output.jpg")
@@ -1,23 +0,0 @@
#!/bin/bash
# run
# accelerate config
# check with
# accelerate env
export MODEL_DIR="PixArt-alpha/PixArt-XL-2-512x512"
export OUTPUT_DIR="output/pixart-controlnet-hf-diffusers-test"
accelerate launch ./train_pixart_controlnet_hf.py --mixed_precision="fp16" \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=fusing/fill50k \
--resolution=512 \
--learning_rate=1e-5 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--report_to="wandb" \
--seed=42 \
--dataloader_num_workers=8
# --lr_scheduler="cosine" --lr_warmup_steps=0 \
File diff suppressed because it is too large Load Diff
@@ -229,11 +229,11 @@ class PromptDiffusionControlNetModel(ControlNetModel):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
@@ -1475,6 +1475,7 @@ def main(args):
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -81,7 +81,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
-461
View File
@@ -1,461 +0,0 @@
import argparse
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
TOKENIZER_MAX_LENGTH = 256
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
parser.add_argument("--vae_encoder_checkpoint_path", default=None, type=str)
parser.add_argument("--vae_decoder_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default=None)
args = parser.parse_args()
# This is specific to `AdaLayerNormContinuous`:
# Diffusers implementation split the linear projection into the scale, shift while Mochi split it into shift, scale
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def swap_proj_gate(weight):
proj, gate = weight.chunk(2, dim=0)
new_weight = torch.cat([gate, proj], dim=0)
return new_weight
def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
original_state_dict = load_file(ckpt_path, device="cpu")
new_state_dict = {}
# Convert patch_embed
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
# Convert time_embed
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias")
new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight")
new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias")
new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight")
new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias")
new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight")
new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias")
new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight")
new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias")
# Convert transformer blocks
num_layers = 48
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
old_prefix = f"blocks.{i}."
# norm1
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias")
if i < num_layers - 1:
new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop(
old_prefix + "mod_y.weight"
)
new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop(
old_prefix + "mod_y.bias"
)
else:
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
old_prefix + "mod_y.weight"
)
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
old_prefix + "mod_y.bias"
)
# Visual attention
qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop(
old_prefix + "attn.q_norm_x.weight"
)
new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop(
old_prefix + "attn.k_norm_x.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attn.proj_x.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias")
# Context attention
qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop(
old_prefix + "attn.q_norm_y.weight"
)
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop(
old_prefix + "attn.k_norm_y.weight"
)
if i < num_layers - 1:
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop(
old_prefix + "attn.proj_y.weight"
)
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop(
old_prefix + "attn.proj_y.bias"
)
# MLP
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
)
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight")
if i < num_layers - 1:
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
original_state_dict.pop(old_prefix + "mlp_y.w1.weight")
)
new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop(
old_prefix + "mlp_y.w2.weight"
)
# Output layers
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("final_layer.mod.weight"), dim=0
)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0)
new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies")
print("Remaining Keys:", original_state_dict.keys())
return new_state_dict
def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_path):
encoder_state_dict = load_file(encoder_ckpt_path, device="cpu")
decoder_state_dict = load_file(decoder_ckpt_path, device="cpu")
new_state_dict = {}
# ==== Decoder =====
prefix = "decoder."
# Convert conv_in
new_state_dict[f"{prefix}conv_in.weight"] = decoder_state_dict.pop("blocks.0.0.weight")
new_state_dict[f"{prefix}conv_in.bias"] = decoder_state_dict.pop("blocks.0.0.bias")
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.5.bias"
)
# Convert up_blocks (MochiUpBlock3D)
down_block_layers = [6, 4, 3] # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4]
for block in range(3):
for i in range(down_block_layers[block]):
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.0.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.0.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.2.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.2.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.3.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.3.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.5.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.5.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.proj.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.4.{i}.stack.0.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.4.{i}.stack.0.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
f"blocks.4.{i}.stack.2.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
f"blocks.4.{i}.stack.2.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.4.{i}.stack.3.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.4.{i}.stack.3.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
f"blocks.4.{i}.stack.5.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
f"blocks.4.{i}.stack.5.bias"
)
# Convert proj_out (Conv1x1 ~= nn.Linear)
new_state_dict[f"{prefix}proj_out.weight"] = decoder_state_dict.pop("output_proj.weight")
new_state_dict[f"{prefix}proj_out.bias"] = decoder_state_dict.pop("output_proj.bias")
print("Remaining Decoder Keys:", decoder_state_dict.keys())
# ==== Encoder =====
prefix = "encoder."
new_state_dict[f"{prefix}proj_in.weight"] = encoder_state_dict.pop("layers.0.weight")
new_state_dict[f"{prefix}proj_in.bias"] = encoder_state_dict.pop("layers.0.bias")
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.5.bias"
)
# Convert down_blocks (MochiDownBlock3D)
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
for block in range(3):
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.0.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.0.bias"
)
for i in range(down_block_layers[block]):
# Convert resnets
new_state_dict[
f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.0.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.2.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.2.bias"
)
new_state_dict[
f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.3.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.5.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.5.bias"
)
# Convert attentions
qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
)
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
# Convert resnets
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.0.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.0.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.2.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.2.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.3.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.3.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.5.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.5.bias"
)
# Convert attentions
qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.norm.bias"
)
# Convert output layers
new_state_dict[f"{prefix}norm_out.norm_layer.weight"] = encoder_state_dict.pop("output_norm.weight")
new_state_dict[f"{prefix}norm_out.norm_layer.bias"] = encoder_state_dict.pop("output_norm.bias")
new_state_dict[f"{prefix}proj_out.weight"] = encoder_state_dict.pop("output_proj.weight")
print("Remaining Encoder Keys:", encoder_state_dict.keys())
return new_state_dict
def main(args):
if args.dtype is None:
dtype = None
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
transformer = None
vae = None
if args.transformer_checkpoint_path is not None:
converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
args.transformer_checkpoint_path
)
transformer = MochiTransformer3DModel()
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
if dtype is not None:
transformer = transformer.to(dtype=dtype)
if args.vae_encoder_checkpoint_path is not None and args.vae_decoder_checkpoint_path is not None:
vae = AutoencoderKLMochi(latent_channels=12, out_channels=3)
converted_vae_state_dict = convert_mochi_vae_state_dict_to_diffusers(
args.vae_encoder_checkpoint_path, args.vae_decoder_checkpoint_path
)
vae.load_state_dict(converted_vae_state_dict, strict=True)
if dtype is not None:
vae = vae.to(dtype=dtype)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
pipe = MochiPipeline(
scheduler=FlowMatchEulerDiscreteScheduler(invert_sigmas=True),
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
if __name__ == "__main__":
main(args)
+1 -1
View File
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.32.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.31.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+3 -15
View File
@@ -1,4 +1,4 @@
__version__ = "0.32.0.dev0"
__version__ = "0.31.0.dev0"
from typing import TYPE_CHECKING
@@ -77,13 +77,10 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel",
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLMochi",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
@@ -103,7 +100,6 @@ else:
"Kandinsky3UNet",
"LatteTransformer3DModel",
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
@@ -241,7 +237,6 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipelines"].extend(
[
"AllegroPipeline",
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
"AmusedImg2ImgPipeline",
@@ -313,7 +308,6 @@ else:
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
"MochiPipeline",
"MusicLDMPipeline",
"PaintByExamplePipeline",
"PIAPipeline",
@@ -487,7 +481,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
@@ -562,13 +556,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
@@ -588,7 +579,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3UNet,
LatteTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
@@ -707,7 +697,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipelines import (
AllegroPipeline,
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
AmusedImg2ImgPipeline,
@@ -777,7 +766,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
MochiPipeline,
MusicLDMPipeline,
PaintByExamplePipeline,
PIAPipeline,
@@ -914,7 +902,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_objects import * # noqa F403
else:
from .models.controlnets.controlnet_flax import FlaxControlNetModel
from .models.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
+3 -56
View File
@@ -97,17 +97,13 @@ class SDCFGCutoffCallback(PipelineCallback):
class SDXLCFGCutoffCallback(PipelineCallback):
"""
Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by
`cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
`cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""
tensor_inputs = [
"prompt_embeds",
"add_text_embeds",
"add_time_ids",
]
tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
@@ -133,55 +129,6 @@ class SDXLCFGCutoffCallback(PipelineCallback):
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
return callback_kwargs
class SDXLControlnetCFGCutoffCallback(PipelineCallback):
"""
Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by
`cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""
tensor_inputs = [
"prompt_embeds",
"add_text_embeds",
"add_time_ids",
"image",
]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)
if step_index == cutoff_step:
prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens
add_time_ids = callback_kwargs[self.tensor_inputs[2]]
add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector
# For Controlnet
image = callback_kwargs[self.tensor_inputs[3]]
image = image[-1:]
pipeline._guidance_scale = 0.0
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
callback_kwargs[self.tensor_inputs[3]] = image
return callback_kwargs
+6 -8
View File
@@ -33,14 +33,16 @@ from .unet_loader_utils import _maybe_expand_lora_scales
if is_transformers_available():
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
)
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
)
logger = logging.get_logger(__name__)
@@ -282,9 +284,7 @@ class IPAdapterMixin:
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
for attn_name, attn_processor in unet.attn_processors.items():
if isinstance(
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
):
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
@@ -342,9 +342,7 @@ class IPAdapterMixin:
)
attn_procs[name] = (
attn_processor_class
if isinstance(
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
)
if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
+118 -124
View File
@@ -51,9 +51,6 @@ if is_accelerate_available():
logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
"""
@@ -184,119 +181,6 @@ def _remove_text_encoder_monkey_patch(text_encoder):
text_encoder._hf_peft_config_loaded = None
def _fetch_state_dict(
pretrained_model_name_or_path_or_dict,
weight_name,
use_safetensors,
local_files_only,
cache_dir,
force_download,
proxies,
token,
revision,
subfolder,
user_agent,
allow_pickle,
):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
# Here we're relaxing the loading check to enable more Inference API
# friendliness where sometimes, it's not at all possible to automatically
# determine `weight_name`.
if weight_name is None:
weight_name = _best_guess_weight_name(
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
pass
if model_file is None:
if weight_name is None:
weight_name = _best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
return state_dict
def _best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
if local_files_only or HF_HUB_OFFLINE:
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
targeted_files = []
if os.path.isfile(pretrained_model_name_or_path_or_dict):
return
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
else:
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
if len(targeted_files) == 0:
return
# "scheduler" does not correspond to a LoRA checkpoint.
# "optimizer" does not correspond to a LoRA checkpoint
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
targeted_files = list(
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
)
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
if len(targeted_files) > 1:
raise ValueError(
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
)
weight_name = targeted_files[0]
return weight_name
class LoraBaseMixin:
"""Utility class for handling LoRAs."""
@@ -350,16 +234,124 @@ class LoraBaseMixin:
return (is_model_cpu_offload, is_sequential_cpu_offload)
@classmethod
def _fetch_state_dict(cls, *args, **kwargs):
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
return _fetch_state_dict(*args, **kwargs)
def _fetch_state_dict(
cls,
pretrained_model_name_or_path_or_dict,
weight_name,
use_safetensors,
local_files_only,
cache_dir,
force_download,
proxies,
token,
revision,
subfolder,
user_agent,
allow_pickle,
):
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
# Here we're relaxing the loading check to enable more Inference API
# friendliness where sometimes, it's not at all possible to automatically
# determine `weight_name`.
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
pass
if model_file is None:
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
return state_dict
@classmethod
def _best_guess_weight_name(cls, *args, **kwargs):
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
return _best_guess_weight_name(*args, **kwargs)
def _best_guess_weight_name(
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
if local_files_only or HF_HUB_OFFLINE:
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
targeted_files = []
if os.path.isfile(pretrained_model_name_or_path_or_dict):
return
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
targeted_files = [
f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
]
else:
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
if len(targeted_files) == 0:
return
# "scheduler" does not correspond to a LoRA checkpoint.
# "optimizer" does not correspond to a LoRA checkpoint
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
targeted_files = list(
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
)
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
if len(targeted_files) > 1:
raise ValueError(
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
)
weight_name = targeted_files[0]
return weight_name
def unload_lora_weights(self):
"""
@@ -733,6 +725,8 @@ class LoraBaseMixin:
save_function: Callable,
safe_serialization: bool,
):
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
+371 -127
View File
@@ -21,6 +21,7 @@ from ..utils import (
USE_PEFT_BACKEND,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
deprecate,
get_adapter_name,
get_peft_kwargs,
@@ -32,7 +33,7 @@ from ..utils import (
logging,
scale_lora_layers,
)
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa
from .lora_base import LoraBaseMixin
from .lora_conversion_utils import (
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
@@ -61,6 +62,9 @@ TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
@@ -218,7 +222,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
state_dict = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -278,9 +282,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -339,9 +341,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -601,9 +601,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -746,7 +744,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
state_dict = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -807,9 +805,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -869,9 +865,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1188,7 +1182,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
state_dict = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -1232,9 +1226,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -1258,17 +1250,13 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
if len(transformer_state_dict) > 0:
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name)
if not hasattr(self, "transformer")
else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
@@ -1313,24 +1301,94 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
state_dict = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
if len(state_dict.keys()) > 0:
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1366,9 +1424,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1686,7 +1742,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
state_dict = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -1763,9 +1819,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
`Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1789,18 +1843,14 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
if len(transformer_state_dict) > 0:
self.load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
transformer=getattr(self, self.transformer_name)
if not hasattr(self, "transformer")
else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
@@ -1831,32 +1881,104 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
transformer (`FluxTransformer2DModel`):
transformer (`SD3Transformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
if transformer_present:
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
state_dict = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
if len(state_dict.keys()) > 0:
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
if network_alphas is not None and len(network_alphas) >= 1:
prefix = cls.transformer_name
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1892,9 +2014,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -2122,10 +2242,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
text_encoder_name = TEXT_ENCODER_NAME
@classmethod
# Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
def load_lora_into_transformer(
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2138,32 +2255,93 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
transformer (`UVit2DModel`):
The Transformer model to load the LoRA layers into.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
# Load the layers corresponding to transformer.
keys = list(state_dict.keys())
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
if transformer_present:
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
state_dict = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
network_alphas = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
if len(state_dict.keys()) > 0:
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2199,9 +2377,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -2443,7 +2619,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
state_dict = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -2482,9 +2658,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -2517,7 +2691,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
@@ -2529,29 +2703,99 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`CogVideoXTransformer3DModel`):
transformer (`SD3Transformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
state_dict = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
if len(state_dict.keys()) > 0:
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
# Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
-223
View File
@@ -16,32 +16,18 @@ import inspect
from functools import partial
from typing import Dict, List, Optional, Union
import torch.nn as nn
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .lora_base import _fetch_state_dict
from .unet_loader_utils import _maybe_expand_lora_scales
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
_SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
@@ -67,215 +53,6 @@ class PeftAdapterMixin:
_hf_peft_config_loaded = False
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
r"""
Loads a LoRA adapter into the underlying model.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
prefix (`str`, *optional*): Prefix to filter the state dict.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
adapter_name = kwargs.pop("adapter_name", None)
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
allow_pickle = False
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(prefix)]
if len(transformer_keys) > 0:
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
if len(state_dict.keys()) > 0:
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(self, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to `_pipeline`.
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def set_adapters(
self,
adapter_names: Union[List[str], str],
+4 -9
View File
@@ -765,7 +765,6 @@ class UNet2DConditionLoadersMixin:
from ..models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
)
if low_cpu_mem_usage:
@@ -805,15 +804,11 @@ class UNet2DConditionLoadersMixin:
if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()
else:
if "XFormers" in str(self.attn_processors[name].__class__):
attn_processor_class = IPAdapterXFormersAttnProcessor
else:
attn_processor_class = (
IPAdapterAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else IPAdapterAttnProcessor
)
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
+14 -33
View File
@@ -28,24 +28,18 @@ if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnets.controlnet_hunyuan"] = [
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DMultiControlNetModel",
]
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -60,10 +54,8 @@ if is_torch_available():
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -78,7 +70,7 @@ if is_torch_available():
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
if is_flax_available():
_import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
@@ -89,32 +81,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
VQModel,
)
from .controlnets import (
ControlNetModel,
ControlNetXSAdapter,
FluxControlNetModel,
FluxMultiControlNetModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DMultiControlNetModel,
MultiControlNetModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SparseControlNetModel,
UNetControlNetXSModel,
)
from .controlnet import ControlNetModel
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from .controlnet_sparsectrl import SparseControlNetModel
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
@@ -124,7 +106,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DModel,
LatteTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,
@@ -148,7 +129,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
if is_flax_available():
from .controlnets import FlaxControlNetModel
from .controlnet_flax import FlaxControlNetModel
from .unets import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
-1
View File
@@ -136,7 +136,6 @@ class SwiGLU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
self.activation = nn.SiLU()
+3 -743
View File
@@ -120,16 +120,14 @@ class Attention(nn.Module):
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
is_causal: bool = False,
):
super().__init__()
# To prevent circular import.
from .normalization import FP32LayerNorm, LpNorm, RMSNorm
from .normalization import FP32LayerNorm, RMSNorm
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
@@ -144,10 +142,8 @@ class Attention(nn.Module):
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.is_causal = is_causal
# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -196,9 +192,6 @@ class Attention(nn.Module):
elif qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
elif qk_norm == "l2":
self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
else:
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
@@ -248,7 +241,7 @@ class Attention(nn.Module):
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
@@ -318,10 +311,7 @@ class Attention(nn.Module):
XFormersAttnAddedKVProcessor,
),
)
is_ip_adapter = hasattr(self, "processor") and isinstance(
self.processor,
(IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
)
if use_memory_efficient_attention_xformers:
if is_added_kv_processor and is_custom_diffusion:
raise NotImplementedError(
@@ -371,19 +361,6 @@ class Attention(nn.Module):
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
elif is_ip_adapter:
processor = IPAdapterXFormersAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
scale=self.processor.scale,
num_tokens=self.processor.num_tokens,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_ip"):
processor.to(
device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
)
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
@@ -1544,100 +1521,6 @@ class FusedJointAttnProcessor2_0:
return hidden_states, encoder_hidden_states
class AllegroAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Apply RoPE if needed
if image_rotary_emb is not None and not attn.is_cross_attention:
from .embeddings import apply_rotary_emb_allegro
query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AuraFlowAttnProcessor2_0:
"""Attention processor used typically in processing Aura Flow."""
@@ -1899,113 +1782,6 @@ class FluxAttnProcessor2_0:
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FluxAttnProcessor2_0_NPU:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
@@ -2117,117 +1893,6 @@ class FusedFluxAttnProcessor2_0:
return hidden_states
class FusedFluxAttnProcessor2_0_NPU:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
# `context` projections.
if encoder_hidden_states is not None:
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
(
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
) = torch.split(encoder_qkv, split_size, dim=-1)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -2738,91 +2403,6 @@ class AttnProcessor2_0:
return hidden_states
class MochiVaeAttnProcessor2_0:
r"""
Attention processor used in Mochi VAE.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
is_single_frame = hidden_states.shape[1] == 1
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if is_single_frame:
hidden_states = attn.to_v(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class StableAudioAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -3498,94 +3078,6 @@ class LuminaAttnProcessor2_0:
return hidden_states
class MochiAttnProcessor2_0:
"""Attention processor used in Mochi."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
if image_rotary_emb is not None:
def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
return torch.stack([cos, sin], dim=-1).flatten(-2)
query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
encoder_query, encoder_key, encoder_value = (
encoder_query.transpose(1, 2),
encoder_key.transpose(1, 2),
encoder_value.transpose(1, 2),
)
sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)
query = torch.cat([query, encoder_query], dim=2)
key = torch.cat([key, encoder_key], dim=2)
value = torch.cat([value, encoder_value], dim=2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
@@ -4558,238 +4050,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
return hidden_states
class IPAdapterXFormersAttnProcessor(torch.nn.Module):
r"""
Attention processor for IP-Adapter using xFormers.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
"""
def __init__(
self,
hidden_size,
cross_attention_dim=None,
num_tokens=(4,),
scale=1.0,
attention_op: Optional[Callable] = None,
):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.attention_op = attention_op
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
self.num_tokens = num_tokens
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
self.to_v_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
):
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
if ip_hidden_states:
if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, List):
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
raise ValueError(
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
f"({len(ip_hidden_states)})"
)
else:
for index, (mask, scale, ip_state) in enumerate(
zip(ip_adapter_masks, self.scale, ip_hidden_states)
):
if mask is None:
continue
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape "
"[1, num_images_for_ip_adapter, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if mask.shape[1] != ip_state.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of ip images ({ip_state.shape[1]}) at index {index}"
)
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of scales ({len(scale)}) at index {index}"
)
else:
ip_adapter_masks = [None] * len(self.scale)
# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
skip = False
if isinstance(scale, list):
if all(s == 0 for s in scale):
skip = True
elif scale == 0:
skip = True
if not skip:
if mask is not None:
mask = mask.to(torch.float16)
if not isinstance(scale, list):
scale = [scale] * mask.shape[1]
current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
_current_ip_hidden_states = xformers.ops.memory_efficient_attention(
query, ip_key, ip_value, op=self.attention_op
)
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
current_ip_hidden_states = xformers.ops.memory_efficient_attention(
query, ip_key, ip_value, op=self.attention_op
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -1,8 +1,6 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_tiny import AutoencoderTiny
File diff suppressed because it is too large Load Diff
@@ -94,13 +94,11 @@ class CogVideoXCausalConv3d(nn.Module):
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
# TODO(aryan): configure calculation based on stride and dilation in the future.
# Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
time_pad = time_kernel_size - 1
height_pad = (height_kernel_size - 1) // 2
width_pad = (width_kernel_size - 1) // 2
self.pad_mode = pad_mode
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
@@ -109,7 +107,7 @@ class CogVideoXCausalConv3d(nn.Module):
self.temporal_dim = 2
self.time_kernel_size = time_kernel_size
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
stride = (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = CogVideoXSafeConv3d(
in_channels=in_channels,
@@ -122,24 +120,18 @@ class CogVideoXCausalConv3d(nn.Module):
def fake_context_parallel_forward(
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.pad_mode == "replicate":
inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
else:
kernel_size = self.time_kernel_size
if kernel_size > 1:
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
inputs = torch.cat(cached_inputs + [inputs], dim=2)
kernel_size = self.time_kernel_size
if kernel_size > 1:
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
inputs = torch.cat(cached_inputs + [inputs], dim=2)
return inputs
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
if self.pad_mode == "replicate":
conv_cache = None
else:
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
output = self.conv(inputs)
return output, conv_cache
File diff suppressed because it is too large Load Diff
+850 -22
View File
@@ -11,32 +11,860 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import deprecate
from .controlnets.controlnet import ( # noqa
BaseOutput,
ControlNetConditioningEmbedding,
ControlNetModel,
ControlNetOutput,
zero_module,
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders.single_file_model import FromOriginalModelMixin
from ..utils import BaseOutput, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from .unets.unet_2d_condition import UNet2DConditionModel
class ControlNetOutput(ControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
deprecate("ControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ControlNetModel(ControlNetModel):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
deprecate("ControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
@dataclass
class ControlNetOutput(BaseOutput):
"""
The output of [`ControlNetModel`].
Args:
down_block_res_samples (`tuple[torch.Tensor]`):
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
used to condition the original UNet's downsampling activations.
mid_down_block_re_sample (`torch.Tensor`):
The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
Output can be used to condition the original UNet's middle block activation.
"""
down_block_res_samples: Tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor
class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
deprecate("ControlNetConditioningEmbedding", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class ControlNetConditioningEmbedding(nn.Module):
"""
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 latent images for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..."
"""
def __init__(
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
):
super().__init__()
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = zero_module(
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
)
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
A ControlNet model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
TODO(Patrick) - unused parameter.
addition_embed_type_num_heads (`int`, defaults to 64):
The number of heads to use for the `TextTimeEmbedding` layer.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
self.down_blocks = nn.ModuleList([])
self.controlnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
# mid
mid_block_channel = block_out_channels[-1]
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_mid_block = controlnet_block
if mid_block_type == "UNetMidBlock2DCrossAttn":
self.mid_block = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
elif mid_block_type == "UNetMidBlock2D":
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
num_layers=0,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
add_attention=False,
)
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
conditioning_channels: int = 3,
):
r"""
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
)
controlnet = cls(
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
down_block_types=unet.config.down_block_types,
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
downsample_padding=unet.config.downsample_padding,
mid_block_scale_factor=unet.config.mid_block_scale_factor,
act_fn=unet.config.act_fn,
norm_num_groups=unet.config.norm_num_groups,
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
attention_head_dim=unet.config.attention_head_dim,
num_attention_heads=unet.config.num_attention_heads,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
mid_block_type=unet.config.mid_block_type,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
if load_weights_from_unet:
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if controlnet.class_embedding:
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
if hasattr(controlnet, "add_embedding"):
controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
return controlnet
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
"""
The [`ControlNetModel`] forward method.
Args:
sample (`torch.Tensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
if self.config.addition_embed_type is not None:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb if aug_emb is not None else emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample = self.mid_block(sample, emb)
# 5. Control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
return ControlNetOutput(
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
@@ -19,11 +19,11 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from ..unets.unet_2d_blocks_flax import (
from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from .unets.unet_2d_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxDownBlock2D,
FlaxUNetMidBlock2DCrossAttn,
+512 -17
View File
@@ -12,30 +12,525 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from ..utils import deprecate, logging
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import PeftAdapterMixin
from ..models.attention_processor import AttentionProcessor
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from .modeling_outputs import Transformer2DModelOutput
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxControlNetOutput(FluxControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
deprecate("FluxControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
@dataclass
class FluxControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
controlnet_single_block_samples: Tuple[torch.Tensor]
class FluxControlNetModel(FluxControlNetModel):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
deprecate("FluxControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
conditioning_embedding_channels: int = None,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_single_layers)
]
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.union = num_mode is not None
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
if conditioning_embedding_channels is not None:
self.input_hint_block = ControlNetConditioningEmbedding(
conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
)
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
else:
self.input_hint_block = None
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self):
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@classmethod
def from_transformer(
cls,
transformer,
num_layers: int = 4,
num_single_layers: int = 10,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
load_weights_from_transformer=True,
):
config = transformer.config
config["num_layers"] = num_layers
config["num_single_layers"] = num_single_layers
config["attention_head_dim"] = attention_head_dim
config["num_attention_heads"] = num_attention_heads
controlnet = cls(**config)
if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
controlnet.single_transformer_blocks.load_state_dict(
transformer.single_transformer_blocks.state_dict(), strict=False
)
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
return controlnet
def forward(
self,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor = None,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
controlnet_mode (`torch.Tensor`):
The mode tensor of shape `(batch_size, 1)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
if self.input_hint_block is not None:
controlnet_cond = self.input_hint_block(controlnet_cond)
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
height = height_pw // self.config.patch_size
width = width_pw // self.config.patch_size
controlnet_cond = controlnet_cond.reshape(
batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
)
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
# add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if self.union:
# union mode
if controlnet_mode is None:
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
# union mode emb
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
block_samples = ()
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
# controlnet block
controlnet_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
controlnet_single_block_samples = ()
for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
# scaling
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
controlnet_single_block_samples = (
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_samples, controlnet_single_block_samples)
return FluxControlNetOutput(
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
class FluxMultiControlNetModel(FluxMultiControlNetModel):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
deprecate("FluxMultiControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class FluxMultiControlNetModel(ModelMixin):
r"""
`FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
compatible with `FluxControlNetModel`.
Args:
controlnets (`List[FluxControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. You must set multiple
`FluxControlNetModel` as a list.
"""
def __init__(self, controlnets):
super().__init__()
self.nets = nn.ModuleList(controlnets)
def forward(
self,
hidden_states: torch.FloatTensor,
controlnet_cond: List[torch.tensor],
controlnet_mode: List[torch.tensor],
conditioning_scale: List[float],
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[FluxControlNetOutput, Tuple]:
# ControlNet-Union with multiple conditions
# only load one ControlNet for saving memories
if len(self.nets) == 1 and self.nets[0].union:
controlnet = self.nets[0]
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
# Regular Multi-ControlNets
# load all ControlNets into memories
else:
for i, (image, mode, scale, controlnet) in enumerate(
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
if block_samples is not None and control_block_samples is not None:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
if single_block_samples is not None and control_single_block_samples is not None:
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
return control_block_samples, control_single_block_samples
@@ -17,17 +17,17 @@ from typing import Dict, Optional, Union
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention_processor import AttentionProcessor
from ..embeddings import (
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .attention_processor import AttentionProcessor
from .controlnet import BaseOutput, Tuple, zero_module
from .embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
PatchEmbed,
PixArtAlphaTextProjection,
)
from ..modeling_utils import ModelMixin
from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock
from .controlnet import BaseOutput, Tuple, zero_module
from .modeling_utils import ModelMixin
from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Some files were not shown because too many files have changed in this diff Show More