Compare commits
42 Commits
ip-11
..
test-mochi
| Author | SHA1 | Date | |
|---|---|---|---|
| 1069d210e1 | |||
| 303b47cc53 | |||
| 59258445aa | |||
| 5d093fed7f | |||
| cae2801bfb | |||
| 72741ec503 | |||
| 237e079e5e | |||
| fb2ede03eb | |||
| 8e11f34a74 | |||
| 5f43c6a41f | |||
| 2798ed4c7a | |||
| d41198cbdf | |||
| c916ae59ea | |||
| ba9f13fa90 | |||
| 723329d1f2 | |||
| 3e569cb644 | |||
| a7372bd9b3 | |||
| 6552653f11 | |||
| 7c55ef5000 | |||
| 0b76fea5dd | |||
| e9e92d043d | |||
| 0a6189eb95 | |||
| 969c3aba88 | |||
| 8700d64d62 | |||
| 85a9825449 | |||
| ebcbad2f38 | |||
| 44987ad98c | |||
| c12ce7d32b | |||
| 275041d21e | |||
| 46f95d5cdb | |||
| 2fd2ec4025 | |||
| ccc1b36b09 | |||
| 85c8734cdc | |||
| 98a4554ac6 | |||
| 1e9bc91b5c | |||
| be5bbe53e1 | |||
| c2a155714b | |||
| 0e9e281ad1 | |||
| 05ebd6cd82 | |||
| da48940b56 | |||
| 64275b0e66 | |||
| e488d09df1 |
@@ -180,62 +180,6 @@ jobs:
|
||||
pip install slack_sdk tabulate
|
||||
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_big_gpu_torch_tests:
|
||||
name: Torch tests on big GPU
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
run: nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Selected Torch CUDA Test on big GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
BIG_GPU_MEMORY: 40
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-m "big_gpu_with_torch_cuda" \
|
||||
--make-reports=tests_big_gpu_torch_cuda \
|
||||
--report-log=tests_big_gpu_torch_cuda.log \
|
||||
tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_big_gpu_torch_cuda_stats.txt
|
||||
cat reports/tests_big_gpu_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_cuda_big_gpu_test_reports
|
||||
path: reports
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_flax_tpu_tests:
|
||||
name: Nightly Flax TPU Tests
|
||||
runs-on: docker-tpu
|
||||
|
||||
@@ -4,13 +4,12 @@ on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
runner_type:
|
||||
description: 'Type of runner to test (aws-g6-4xlarge-plus: a10, aws-g4dn-2xlarge: t4, aws-g6e-xlarge-plus: L40)'
|
||||
description: 'Type of runner to test (aws-g6-4xlarge-plus: a10 or aws-g4dn-2xlarge: t4)'
|
||||
type: choice
|
||||
required: true
|
||||
options:
|
||||
- aws-g6-4xlarge-plus
|
||||
- aws-g4dn-2xlarge
|
||||
- aws-g6e-xlarge-plus
|
||||
docker_image:
|
||||
description: 'Name of the Docker image'
|
||||
required: true
|
||||
|
||||
@@ -252,8 +252,6 @@
|
||||
title: SparseControlNetModel
|
||||
title: ControlNets
|
||||
- sections:
|
||||
- local: api/models/allegro_transformer3d
|
||||
title: AllegroTransformer3DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/cogvideox_transformer3d
|
||||
@@ -304,8 +302,6 @@
|
||||
- sections:
|
||||
- local: api/models/autoencoderkl
|
||||
title: AutoencoderKL
|
||||
- local: api/models/autoencoderkl_allegro
|
||||
title: AutoencoderKLAllegro
|
||||
- local: api/models/autoencoderkl_cogvideox
|
||||
title: AutoencoderKLCogVideoX
|
||||
- local: api/models/autoencoderkl_mochi
|
||||
@@ -326,8 +322,6 @@
|
||||
sections:
|
||||
- local: api/pipelines/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
- local: api/pipelines/amused
|
||||
title: aMUSEd
|
||||
- local: api/pipelines/animatediff
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AllegroTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data from [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AllegroTransformer3DModel
|
||||
|
||||
vae = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
```
|
||||
|
||||
## AllegroTransformer3DModel
|
||||
|
||||
[[autodoc]] AllegroTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -1,37 +0,0 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLAllegro
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLAllegro
|
||||
|
||||
vae = AutoencoderKLCogVideoX.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLAllegro
|
||||
|
||||
[[autodoc]] AutoencoderKLAllegro
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -39,7 +39,7 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
|
||||
|
||||
## ControlNetOutput
|
||||
|
||||
[[autodoc]] models.controlnets.controlnet.ControlNetOutput
|
||||
[[autodoc]] models.controlnet.ControlNetOutput
|
||||
|
||||
## FlaxControlNetModel
|
||||
|
||||
@@ -47,4 +47,4 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
|
||||
|
||||
## FlaxControlNetOutput
|
||||
|
||||
[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput
|
||||
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
|
||||
|
||||
@@ -38,5 +38,5 @@ pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-di
|
||||
|
||||
## SD3ControlNetOutput
|
||||
|
||||
[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput
|
||||
[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
|
||||
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Allegro
|
||||
|
||||
[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, Huan Yang.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## AllegroPipeline
|
||||
|
||||
[[autodoc]] AllegroPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AllegroPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput
|
||||
@@ -183,7 +183,7 @@ Add the transformer model to the pipeline for denoising, but set the other model
|
||||
|
||||
```py
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
"black-forest-labs/FLUX.1-dev", ,
|
||||
text_encoder=None,
|
||||
text_encoder_2=None,
|
||||
tokenizer=None,
|
||||
|
||||
@@ -1650,8 +1650,6 @@ def main(args):
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
pass # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1778,10 +1776,15 @@ def main(args):
|
||||
if not args.enable_t5_ti:
|
||||
# pure textual inversion - only clip
|
||||
if pure_textual_inversion:
|
||||
params_to_optimize = [text_parameters_one_with_lr]
|
||||
params_to_optimize = [
|
||||
text_parameters_one_with_lr,
|
||||
]
|
||||
te_idx = 0
|
||||
else: # regular te training or regular pivotal for clip
|
||||
params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
|
||||
params_to_optimize = [
|
||||
transformer_parameters_with_lr,
|
||||
text_parameters_one_with_lr,
|
||||
]
|
||||
te_idx = 1
|
||||
elif args.enable_t5_ti:
|
||||
# pivotal tuning of clip & t5
|
||||
@@ -1804,7 +1807,9 @@ def main(args):
|
||||
]
|
||||
te_idx = 1
|
||||
else:
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
params_to_optimize = [
|
||||
transformer_parameters_with_lr,
|
||||
]
|
||||
|
||||
# Optimizer creation
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
@@ -1864,6 +1869,7 @@ def main(args):
|
||||
params_to_optimize[-1]["lr"] = args.learning_rate
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -67,7 +67,6 @@ from diffusers.utils import (
|
||||
convert_state_dict_to_kohya,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -80,27 +79,30 @@ logger = get_logger(__name__)
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
use_dora: bool,
|
||||
images: list = None,
|
||||
base_model: str = None,
|
||||
images=None,
|
||||
base_model=str,
|
||||
train_text_encoder=False,
|
||||
train_text_encoder_ti=False,
|
||||
token_abstraction_dict=None,
|
||||
instance_prompt=None,
|
||||
validation_prompt=None,
|
||||
instance_prompt=str,
|
||||
validation_prompt=str,
|
||||
repo_folder=None,
|
||||
vae_path=None,
|
||||
):
|
||||
img_str = "widget:\n"
|
||||
lora = "lora" if not use_dora else "dora"
|
||||
|
||||
widget_dict = []
|
||||
if images is not None:
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
widget_dict.append(
|
||||
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
|
||||
)
|
||||
else:
|
||||
widget_dict.append({"text": instance_prompt})
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"""
|
||||
- text: '{validation_prompt if validation_prompt else ' ' }'
|
||||
output:
|
||||
url:
|
||||
"image_{i}.png"
|
||||
"""
|
||||
if not images:
|
||||
img_str += f"""
|
||||
- text: '{instance_prompt}'
|
||||
"""
|
||||
embeddings_filename = f"{repo_folder}_emb"
|
||||
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
|
||||
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
|
||||
@@ -135,7 +137,24 @@ pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_en
|
||||
trigger_str += f"""
|
||||
to trigger concept `{key}` → use `{tokens}` in your prompt \n
|
||||
"""
|
||||
model_description = f"""
|
||||
|
||||
yaml = f"""---
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- diffusers-training
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- {lora}
|
||||
- template:sd-lora
|
||||
{img_str}
|
||||
base_model: {base_model}
|
||||
instance_prompt: {instance_prompt}
|
||||
license: openrail++
|
||||
---
|
||||
"""
|
||||
|
||||
model_card = f"""
|
||||
# SD1.5 LoRA DreamBooth - {repo_id}
|
||||
|
||||
<Gallery />
|
||||
@@ -183,28 +202,8 @@ Pivotal tuning was enabled: {train_text_encoder_ti}.
|
||||
Special VAE used for training: {vae_path}.
|
||||
|
||||
"""
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="openrail++",
|
||||
base_model=base_model,
|
||||
prompt=instance_prompt,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
widget=widget_dict,
|
||||
)
|
||||
|
||||
tags = [
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"diffusers-training",
|
||||
lora,
|
||||
"template:sd-lora" "stable-diffusion",
|
||||
"stable-diffusion-diffusers",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(
|
||||
@@ -1359,7 +1358,10 @@ def main(args):
|
||||
else args.adam_weight_decay,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr]
|
||||
params_to_optimize = [
|
||||
unet_lora_parameters_with_lr,
|
||||
text_lora_parameters_one_with_lr,
|
||||
]
|
||||
else:
|
||||
params_to_optimize = [unet_lora_parameters_with_lr]
|
||||
|
||||
@@ -1421,6 +1423,7 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1794,6 +1794,7 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -947,6 +947,7 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -969,6 +969,7 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -10,7 +10,6 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
|
||||
| Example | Description | Code Example | Colab | Author |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|
||||
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|
||||
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|
||||
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[](https://huggingface.co/spaces/exx8/differential-diffusion) [](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
|
||||
| HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
|
||||
@@ -74,7 +73,6 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
|
||||
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
|
||||
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
|
||||
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
|
||||
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
|
||||
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
|
||||
|
||||
@@ -86,161 +84,6 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
|
||||
|
||||
## Example usages
|
||||
|
||||
### Adaptive Mask Inpainting
|
||||
|
||||
**Hyeonwoo Kim\*, Sookwan Han\*, Patrick Kwon, Hanbyul Joo**
|
||||
|
||||
**Seoul National University, Naver Webtoon**
|
||||
|
||||
Adaptive Mask Inpainting, presented in the ECCV'24 oral paper [*Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models*](https://snuvclab.github.io/coma), is an algorithm designed to insert humans into scene images without altering the background. Traditional inpainting methods often fail to preserve object geometry and details within the masked region, leading to false affordances. Adaptive Mask Inpainting addresses this issue by progressively specifying the inpainting region over diffusion timesteps, ensuring that the inserted human integrates seamlessly with the existing scene.
|
||||
|
||||
Here is the demonstration of Adaptive Mask Inpainting:
|
||||
|
||||
<video controls>
|
||||
<source src="https://snuvclab.github.io/coma/static/videos/adaptive_mask_inpainting_vis.mp4" type="video/mp4">
|
||||
Your browser does not support the video tag.
|
||||
</video>
|
||||
|
||||

|
||||
|
||||
|
||||
You can find additional information about Adaptive Mask Inpainting in the [paper](https://arxiv.org/pdf/2401.12978) or in the [project website](https://snuvclab.github.io/coma).
|
||||
|
||||
#### Usage example
|
||||
First, clone the diffusers github repository, and run the following command to set environment.
|
||||
```Shell
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
|
||||
conda create --name ami python=3.9 -y
|
||||
conda activate ami
|
||||
|
||||
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y
|
||||
python -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
|
||||
pip install easydict
|
||||
pip install diffusers==0.20.2 accelerate safetensors transformers
|
||||
pip install setuptools==59.5.0
|
||||
pip install opencv-python
|
||||
pip install numpy==1.24.1
|
||||
```
|
||||
Then, run the below code under 'diffusers' directory.
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import DDIMScheduler
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
from examples.community.adaptive_mask_inpainting import download_file, AdaptiveMaskInpaintPipeline, AMI_INSTALL_MESSAGE
|
||||
|
||||
print(AMI_INSTALL_MESSAGE)
|
||||
|
||||
from easydict import EasyDict
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Download Necessary Files
|
||||
"""
|
||||
download_file(
|
||||
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/model_final_edd263.pkl?download=true",
|
||||
output_file = "model_final_edd263.pkl",
|
||||
exist_ok=True,
|
||||
)
|
||||
download_file(
|
||||
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/pointrend_rcnn_R_50_FPN_3x_coco.yaml?download=true",
|
||||
output_file = "pointrend_rcnn_R_50_FPN_3x_coco.yaml",
|
||||
exist_ok=True,
|
||||
)
|
||||
download_file(
|
||||
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_img.png?download=true",
|
||||
output_file = "input_img.png",
|
||||
exist_ok=True,
|
||||
)
|
||||
download_file(
|
||||
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_mask.png?download=true",
|
||||
output_file = "input_mask.png",
|
||||
exist_ok=True,
|
||||
)
|
||||
download_file(
|
||||
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-PointRend-RCNN-FPN.yaml?download=true",
|
||||
output_file = "Base-PointRend-RCNN-FPN.yaml",
|
||||
exist_ok=True,
|
||||
)
|
||||
download_file(
|
||||
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-RCNN-FPN.yaml?download=true",
|
||||
output_file = "Base-RCNN-FPN.yaml",
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
"""
|
||||
Prepare Adaptive Mask Inpainting Pipeline
|
||||
"""
|
||||
# device
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
num_steps = 50
|
||||
|
||||
# Scheduler
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False
|
||||
)
|
||||
scheduler.set_timesteps(num_inference_steps=num_steps)
|
||||
|
||||
## load models as pipelines
|
||||
pipeline = AdaptiveMaskInpaintPipeline.from_pretrained(
|
||||
"Uminosachi/realisticVisionV51_v51VAE-inpainting",
|
||||
scheduler=scheduler,
|
||||
torch_dtype=torch.float16,
|
||||
requires_safety_checker=False
|
||||
).to(device)
|
||||
|
||||
## disable safety checker
|
||||
enable_safety_checker = False
|
||||
if not enable_safety_checker:
|
||||
pipeline.safety_checker = None
|
||||
|
||||
"""
|
||||
Run Adaptive Mask Inpainting
|
||||
"""
|
||||
default_mask_image = Image.open("./input_mask.png").convert("L")
|
||||
init_image = Image.open("./input_img.png").convert("RGB")
|
||||
|
||||
|
||||
seed = 59
|
||||
generator = torch.Generator(device=device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
image = pipeline(
|
||||
prompt="a man sitting on a couch",
|
||||
negative_prompt="worst quality, normal quality, low quality, bad anatomy, artifacts, blurry, cropped, watermark, greyscale, nsfw",
|
||||
image=init_image,
|
||||
default_mask_image=default_mask_image,
|
||||
guidance_scale=11.0,
|
||||
strength=0.98,
|
||||
use_adaptive_mask=True,
|
||||
generator=generator,
|
||||
enforce_full_mask_ratio=0.0,
|
||||
visualization_save_dir="./ECCV2024_adaptive_mask_inpainting_demo", # DON'T CHANGE THIS!!!
|
||||
human_detection_thres=0.015,
|
||||
).images[0]
|
||||
|
||||
|
||||
image.save(f'final_img.png')
|
||||
```
|
||||
#### [Troubleshooting]
|
||||
|
||||
If you run into an error `cannot import name 'cached_download' from 'huggingface_hub'` (issue [1851](https://github.com/easydiffusion/easydiffusion/issues/1851)), remove `cached_download` from the import line in the file `diffusers/utils/dynamic_modules_utils.py`.
|
||||
|
||||
For example, change the import line from `.../env/lib/python3.8/site-packages/diffusers/utils/dynamic_modules_utils.py`.
|
||||
|
||||
|
||||
### Flux with CFG
|
||||
|
||||
Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
|
||||
@@ -4602,94 +4445,3 @@ grid_image.save(grid_dir + "sample.png")
|
||||
`pag_scale` : guidance scale of PAG (ex: 5.0)
|
||||
|
||||
`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
|
||||
|
||||
# PIXART-α Controlnet pipeline
|
||||
|
||||
[Project](https://pixart-alpha.github.io/) / [GitHub](https://github.com/PixArt-alpha/PixArt-alpha/blob/master/asset/docs/pixart_controlnet.md)
|
||||
|
||||
This the implementation of the controlnet model and the pipelne for the Pixart-alpha model, adapted to use the HuggingFace Diffusers.
|
||||
|
||||
## Example Usage
|
||||
|
||||
This example uses the Pixart HED Controlnet model, converted from the control net model as trained by the authors of the paper.
|
||||
|
||||
```py
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
from diffusers.image_processor import PixArtImageProcessor
|
||||
|
||||
from controlnet_aux import HEDdetector
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
|
||||
|
||||
controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
|
||||
|
||||
weight_dtype = torch.float16
|
||||
image_size = 1024
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
# load controlnet
|
||||
controlnet = PixArtControlNetAdapterModel.from_pretrained(
|
||||
controlnet_repo_id,
|
||||
torch_dtype=weight_dtype,
|
||||
use_safetensors=True,
|
||||
).to(device)
|
||||
|
||||
pipe = PixArtAlphaControlnetPipeline.from_pretrained(
|
||||
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
||||
controlnet=controlnet,
|
||||
torch_dtype=weight_dtype,
|
||||
use_safetensors=True,
|
||||
).to(device)
|
||||
|
||||
images_path = "images"
|
||||
control_image_file = "0_7.jpg"
|
||||
|
||||
prompt = "battleship in space, galaxy in background"
|
||||
|
||||
control_image_name = control_image_file.split('.')[0]
|
||||
|
||||
control_image = load_image(f"{images_path}/{control_image_file}")
|
||||
print(control_image.size)
|
||||
height, width = control_image.size
|
||||
|
||||
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
condition_transform = T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB')),
|
||||
T.CenterCrop([image_size, image_size]),
|
||||
])
|
||||
|
||||
control_image = condition_transform(control_image)
|
||||
hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
|
||||
|
||||
hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
|
||||
|
||||
# run pipeline
|
||||
with torch.no_grad():
|
||||
out = pipe(
|
||||
prompt=prompt,
|
||||
image=hed_edge,
|
||||
num_inference_steps=14,
|
||||
guidance_scale=4.5,
|
||||
height=image_size,
|
||||
width=image_size,
|
||||
)
|
||||
|
||||
out.images[0].save(f"{images_path}//{control_image_name}_output.jpg")
|
||||
|
||||
```
|
||||
|
||||
In the folder examples/pixart there is also a script that can be used to train new models.
|
||||
Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -152,7 +152,6 @@ def log_validation(
|
||||
guidance_scale=3.5,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
image = image.resize((args.resolution, args.resolution))
|
||||
images.append(image)
|
||||
image_logs.append(
|
||||
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
|
||||
|
||||
@@ -170,21 +170,6 @@ accelerate launch train_dreambooth_lora_flux.py \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
|
||||
|
||||
### Text Encoder Training
|
||||
|
||||
Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
|
||||
|
||||
@@ -147,40 +147,6 @@ accelerate launch train_dreambooth_lora_sd3.py \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### Targeting Specific Blocks & Layers
|
||||
As image generation models get bigger & more powerful, more fine-tuners come to find that training only part of the
|
||||
transformer blocks (sometimes as little as two) can be enough to get great results.
|
||||
In some cases, it can be even better to maintain some of the blocks/layers frozen.
|
||||
|
||||
For **SD3.5-Large** specifically, you may find this information useful (taken from: [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6#12461cdcd19680788a23c650dab26b93):
|
||||
> [!NOTE]
|
||||
> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24` )* influence the overall composition/primary form more.
|
||||
> So, freezing other layers/targeting specific layers is a viable approach.
|
||||
> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps.
|
||||
> **Photorealism**
|
||||
> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing detail degradation introduced by small dataset from happening.
|
||||
> **Anatomy preservation**
|
||||
> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks.
|
||||
|
||||
|
||||
We've added `--lora_layers` and `--lora_blocks` to make LoRA training modules configurable.
|
||||
- with `--lora_blocks` you can specify the block numbers for training. E.g. passing -
|
||||
```diff
|
||||
--lora_blocks "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37"
|
||||
```
|
||||
will trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained.
|
||||
- with `--lora_layers` you can specify the types of layers you wish to train.
|
||||
By default, the trained layers are -
|
||||
`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v`
|
||||
If you wish to have a leaner LoRA / train more blocks over layers you could pass -
|
||||
```diff
|
||||
+ --lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0
|
||||
```
|
||||
This will reduce LoRA size by roughly 50% for the same rank compared to the default.
|
||||
However, if you're after compact LoRAs, it's our impression that maintaining the default setting for `--lora_layers` and
|
||||
freezing some of the early & blocks is usually better.
|
||||
|
||||
|
||||
### Text Encoder Training
|
||||
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
|
||||
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
|
||||
|
||||
@@ -37,7 +37,6 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"
|
||||
transformer_layer_type = "single_transformer_blocks.0.attn.to_k"
|
||||
|
||||
def test_dreambooth_lora_flux(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -137,43 +136,6 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
|
||||
starts_with_transformer = all(
|
||||
key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
|
||||
@@ -38,9 +38,6 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"
|
||||
|
||||
transformer_block_idx = 0
|
||||
layer_type = "attn.to_k"
|
||||
|
||||
def test_dreambooth_lora_sd3(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
@@ -139,74 +136,6 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_block(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--lora_blocks {self.transformer_block_idx}
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
# In this test, only params of transformer block 0 should be in the state dict
|
||||
starts_with_transformer = all(
|
||||
key.startswith("transformer.transformer_blocks.0") for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layer(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--lora_layers {self.layer_type}
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# In this test, only transformer params of attention layers `attn.to_k` should be in the state dict
|
||||
starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
|
||||
@@ -57,7 +57,6 @@ from diffusers.utils import (
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_torch_npu_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
@@ -69,12 +68,6 @@ check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
torch.npu.config.allow_internal_format = False
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
@@ -168,7 +161,7 @@ def log_validation(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
@@ -196,8 +189,6 @@ def log_validation(
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch_npu.npu.empty_cache()
|
||||
|
||||
return images
|
||||
|
||||
@@ -1044,9 +1035,7 @@ def main(args):
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
has_supported_fp16_accelerator = (
|
||||
torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()
|
||||
)
|
||||
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
|
||||
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
|
||||
if args.prior_generation_precision == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
@@ -1084,8 +1073,6 @@ def main(args):
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch_npu.npu.empty_cache()
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -1239,7 +1226,10 @@ def main(args):
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
|
||||
params_to_optimize = [
|
||||
transformer_parameters_with_lr,
|
||||
text_parameters_one_with_lr,
|
||||
]
|
||||
else:
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
|
||||
@@ -1298,9 +1288,11 @@ def main(args):
|
||||
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
|
||||
# --learning_rate
|
||||
params_to_optimize[1]["lr"] = args.learning_rate
|
||||
params_to_optimize[2]["lr"] = args.learning_rate
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
@@ -1367,8 +1359,6 @@ def main(args):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch_npu.npu.empty_cache()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1590,7 +1580,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# handle guidance
|
||||
if accelerator.unwrap_model(transformer).config.guidance_embeds:
|
||||
if transformer.config.guidance_embeds:
|
||||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
@@ -1704,8 +1694,6 @@ def main(args):
|
||||
# create pipeline
|
||||
if not args.train_text_encoder:
|
||||
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
|
||||
text_encoder_one.to(weight_dtype)
|
||||
text_encoder_two.to(weight_dtype)
|
||||
else: # even when training the text encoder we're only training text encoder one
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
@@ -1734,15 +1722,9 @@ def main(args):
|
||||
)
|
||||
if not args.train_text_encoder:
|
||||
del text_encoder_one, text_encoder_two
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch_npu.npu.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
@@ -1801,9 +1783,6 @@ def main(args):
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
@@ -177,7 +177,7 @@ def log_validation(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
@@ -554,15 +554,6 @@ def parse_args(input_args=None):
|
||||
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lora_layers",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--adam_epsilon",
|
||||
type=float,
|
||||
@@ -1195,30 +1186,12 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.gradient_checkpointing_enable()
|
||||
|
||||
if args.lora_layers is not None:
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
else:
|
||||
target_modules = [
|
||||
"attn.to_k",
|
||||
"attn.to_q",
|
||||
"attn.to_v",
|
||||
"attn.to_out.0",
|
||||
"attn.add_k_proj",
|
||||
"attn.add_q_proj",
|
||||
"attn.add_v_proj",
|
||||
"attn.to_add_out",
|
||||
"ff.net.0.proj",
|
||||
"ff.net.2",
|
||||
"ff_context.net.0.proj",
|
||||
"ff_context.net.2",
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
transformer.add_adapter(transformer_lora_config)
|
||||
if args.train_text_encoder:
|
||||
@@ -1335,7 +1308,10 @@ def main(args):
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
|
||||
params_to_optimize = [
|
||||
transformer_parameters_with_lr,
|
||||
text_parameters_one_with_lr,
|
||||
]
|
||||
else:
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
|
||||
@@ -1391,12 +1367,14 @@ def main(args):
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
)
|
||||
# changes the learning rate of text_encoder_parameters_one to be
|
||||
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
|
||||
# --learning_rate
|
||||
params_to_optimize[1]["lr"] = args.learning_rate
|
||||
params_to_optimize[2]["lr"] = args.learning_rate
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
@@ -1706,7 +1684,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# handle guidance
|
||||
if accelerator.unwrap_model(transformer).config.guidance_embeds:
|
||||
if transformer.config.guidance_embeds:
|
||||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
@@ -1819,8 +1797,6 @@ def main(args):
|
||||
# create pipeline
|
||||
if not args.train_text_encoder:
|
||||
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
|
||||
text_encoder_one.to(weight_dtype)
|
||||
text_encoder_two.to(weight_dtype)
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
@@ -1844,9 +1820,6 @@ def main(args):
|
||||
del text_encoder_one, text_encoder_two
|
||||
free_memory()
|
||||
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
@@ -1911,9 +1884,6 @@ def main(args):
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
@@ -571,25 +571,6 @@ def parse_args(input_args=None):
|
||||
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lora_layers",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The transformer block layers to apply LoRA training on. Please specify the layers in a comma seperated string."
|
||||
"For examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_blocks",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The transformer blocks to apply LoRA training on. Please specify the block numbers in a comma seperated manner."
|
||||
'E.g. - "--lora_blocks 12,30" will result in lora training of transformer blocks 12 and 30. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md'
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--adam_epsilon",
|
||||
type=float,
|
||||
@@ -1241,31 +1222,13 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.gradient_checkpointing_enable()
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
if args.lora_layers is not None:
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
else:
|
||||
target_modules = [
|
||||
"attn.add_k_proj",
|
||||
"attn.add_q_proj",
|
||||
"attn.add_v_proj",
|
||||
"attn.to_add_out",
|
||||
"attn.to_k",
|
||||
"attn.to_out.0",
|
||||
"attn.to_q",
|
||||
"attn.to_v",
|
||||
]
|
||||
if args.lora_blocks is not None:
|
||||
target_blocks = [int(block.strip()) for block in args.lora_blocks.split(",")]
|
||||
target_modules = [
|
||||
f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
transformer.add_adapter(transformer_lora_config)
|
||||
|
||||
@@ -1468,6 +1431,7 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1402,6 +1402,7 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -902,26 +902,20 @@ def _encode_prompt_with_clip(
|
||||
tokenizer,
|
||||
prompt: str,
|
||||
device=None,
|
||||
text_input_ids=None,
|
||||
num_images_per_prompt: int = 1,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if tokenizer is not None:
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
else:
|
||||
if text_input_ids is None:
|
||||
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
@@ -943,7 +937,6 @@ def encode_prompt(
|
||||
max_sequence_length,
|
||||
device=None,
|
||||
num_images_per_prompt: int = 1,
|
||||
text_input_ids_list=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -952,14 +945,13 @@ def encode_prompt(
|
||||
|
||||
clip_prompt_embeds_list = []
|
||||
clip_pooled_prompt_embeds_list = []
|
||||
for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
|
||||
for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
|
||||
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
device=device if device is not None else text_encoder.device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
|
||||
)
|
||||
clip_prompt_embeds_list.append(prompt_embeds)
|
||||
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
||||
@@ -1336,6 +1328,7 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1,13 +1,4 @@
|
||||
|
||||
## Diffusion-based Policy Learning for RL
|
||||
|
||||
`diffusion_policy` implements [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/), a diffusion model that predicts robot action sequences in reinforcement learning tasks.
|
||||
|
||||
This example implements a robot control model for pushing a T-shaped block into a target area. The model takes in current state observations as input, and outputs a trajectory of subsequent steps to follow.
|
||||
|
||||
To execute the script, run `diffusion_policy.py`
|
||||
|
||||
## Diffuser Locomotion
|
||||
# Overview
|
||||
|
||||
These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
|
||||
There are two ways to use the script, `run_diffuser_locomotion.py`.
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
import numpy as np
|
||||
import numpy.core.multiarray as multiarray
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
from torch.serialization import add_safe_globals
|
||||
|
||||
from diffusers import DDPMScheduler, UNet1DModel
|
||||
|
||||
|
||||
add_safe_globals(
|
||||
[
|
||||
multiarray._reconstruct,
|
||||
np.ndarray,
|
||||
np.dtype,
|
||||
np.dtype(np.float32).type,
|
||||
np.dtype(np.float64).type,
|
||||
np.dtype(np.int32).type,
|
||||
np.dtype(np.int64).type,
|
||||
type(np.dtype(np.float32)),
|
||||
type(np.dtype(np.float64)),
|
||||
type(np.dtype(np.int32)),
|
||||
type(np.dtype(np.int64)),
|
||||
]
|
||||
)
|
||||
|
||||
"""
|
||||
An example of using HuggingFace's diffusers library for diffusion policy,
|
||||
generating smooth movement trajectories.
|
||||
|
||||
This implements a robot control model for pushing a T-shaped block into a target area.
|
||||
The model takes in the robot arm position, block position, and block angle,
|
||||
then outputs a sequence of 16 (x,y) positions for the robot arm to follow.
|
||||
"""
|
||||
|
||||
|
||||
class ObservationEncoder(nn.Module):
|
||||
"""
|
||||
Converts raw robot observations (positions/angles) into a more compact representation
|
||||
|
||||
state_dim (int): Dimension of the input state vector (default: 5)
|
||||
[robot_x, robot_y, block_x, block_y, block_angle]
|
||||
|
||||
- Input shape: (batch_size, state_dim)
|
||||
- Output shape: (batch_size, 256)
|
||||
"""
|
||||
|
||||
def __init__(self, state_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(nn.Linear(state_dim, 512), nn.ReLU(), nn.Linear(512, 256))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class ObservationProjection(nn.Module):
|
||||
"""
|
||||
Takes the encoded observation and transforms it into 32 values that represent the current robot/block situation.
|
||||
These values are used as additional contextual information during the diffusion model's trajectory generation.
|
||||
|
||||
- Input: 256-dim vector (padded to 512)
|
||||
Shape: (batch_size, 256)
|
||||
- Output: 32 contextual information values for the diffusion model
|
||||
Shape: (batch_size, 32)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.randn(32, 512))
|
||||
self.bias = nn.Parameter(torch.zeros(32))
|
||||
|
||||
def forward(self, x): # pad 256-dim input to 512-dim with zeros
|
||||
if x.size(-1) == 256:
|
||||
x = torch.cat([x, torch.zeros(*x.shape[:-1], 256, device=x.device)], dim=-1)
|
||||
return nn.functional.linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class DiffusionPolicy:
|
||||
"""
|
||||
Implements diffusion policy for generating robot arm trajectories.
|
||||
Uses diffusion to generate sequences of positions for a robot arm, conditioned on
|
||||
the current state of the robot and the block it needs to push.
|
||||
|
||||
The model expects observations in pixel coordinates (0-512 range) and block angle in radians.
|
||||
It generates trajectories as sequences of (x,y) coordinates also in the 0-512 range.
|
||||
"""
|
||||
|
||||
def __init__(self, state_dim=5, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
# define valid ranges for inputs/outputs
|
||||
self.stats = {
|
||||
"obs": {"min": torch.zeros(5), "max": torch.tensor([512, 512, 512, 512, 2 * np.pi])},
|
||||
"action": {"min": torch.zeros(2), "max": torch.full((2,), 512)},
|
||||
}
|
||||
|
||||
self.obs_encoder = ObservationEncoder(state_dim).to(device)
|
||||
self.obs_projection = ObservationProjection().to(device)
|
||||
|
||||
# UNet model that performs the denoising process
|
||||
# takes in concatenated action (2 channels) and context (32 channels) = 34 channels
|
||||
# outputs predicted action (2 channels for x,y coordinates)
|
||||
self.model = UNet1DModel(
|
||||
sample_size=16, # length of trajectory sequence
|
||||
in_channels=34,
|
||||
out_channels=2,
|
||||
layers_per_block=2, # number of layers per each UNet block
|
||||
block_out_channels=(128,), # number of output neurons per layer in each block
|
||||
down_block_types=("DownBlock1D",), # reduce the resolution of data
|
||||
up_block_types=("UpBlock1D",), # increase the resolution of data
|
||||
).to(device)
|
||||
|
||||
# noise scheduler that controls the denoising process
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
num_train_timesteps=100, # number of denoising steps
|
||||
beta_schedule="squaredcos_cap_v2", # type of noise schedule
|
||||
)
|
||||
|
||||
# load pre-trained weights from HuggingFace
|
||||
checkpoint = torch.load(
|
||||
hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), weights_only=True, map_location=device
|
||||
)
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
self.obs_encoder.load_state_dict(checkpoint["encoder_state_dict"])
|
||||
self.obs_projection.load_state_dict(checkpoint["projection_state_dict"])
|
||||
|
||||
# scales data to [-1, 1] range for neural network processing
|
||||
def normalize_data(self, data, stats):
|
||||
return ((data - stats["min"]) / (stats["max"] - stats["min"])) * 2 - 1
|
||||
|
||||
# converts normalized data back to original range
|
||||
def unnormalize_data(self, ndata, stats):
|
||||
return ((ndata + 1) / 2) * (stats["max"] - stats["min"]) + stats["min"]
|
||||
|
||||
@torch.no_grad()
|
||||
def predict(self, observation):
|
||||
"""
|
||||
Generates a trajectory of robot arm positions given the current state.
|
||||
|
||||
Args:
|
||||
observation (torch.Tensor): Current state [robot_x, robot_y, block_x, block_y, block_angle]
|
||||
Shape: (batch_size, 5)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sequence of (x,y) positions for the robot arm to follow
|
||||
Shape: (batch_size, 16, 2) where:
|
||||
- 16 is the number of steps in the trajectory
|
||||
- 2 is the (x,y) coordinates in pixel space (0-512)
|
||||
|
||||
The function first encodes the observation, then uses it to condition a diffusion
|
||||
process that gradually denoises random trajectories into smooth, purposeful movements.
|
||||
"""
|
||||
observation = observation.to(self.device)
|
||||
normalized_obs = self.normalize_data(observation, self.stats["obs"])
|
||||
|
||||
# encode the observation into context values for the diffusion model
|
||||
cond = self.obs_projection(self.obs_encoder(normalized_obs))
|
||||
# keeps first & second dimension sizes unchanged, and multiplies last dimension by 16
|
||||
cond = cond.view(normalized_obs.shape[0], -1, 1).expand(-1, -1, 16)
|
||||
|
||||
# initialize action with noise - random noise that will be refined into a trajectory
|
||||
action = torch.randn((observation.shape[0], 2, 16), device=self.device)
|
||||
|
||||
# denoise
|
||||
# at each step `t`, the current noisy trajectory (`action`) & conditioning info (context) are
|
||||
# fed into the model to predict a denoised trajectory, then uses self.noise_scheduler.step to
|
||||
# apply this prediction & slightly reduce the noise in `action` more
|
||||
|
||||
self.noise_scheduler.set_timesteps(100)
|
||||
for t in self.noise_scheduler.timesteps:
|
||||
model_output = self.model(torch.cat([action, cond], dim=1), t)
|
||||
action = self.noise_scheduler.step(model_output.sample, t, action).prev_sample
|
||||
|
||||
action = action.transpose(1, 2) # reshape to [batch, 16, 2]
|
||||
action = self.unnormalize_data(action, self.stats["action"]) # scale back to coordinates
|
||||
return action
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
policy = DiffusionPolicy()
|
||||
|
||||
# sample of a single observation
|
||||
# robot arm starts in center, block is slightly left and up, rotated 90 degrees
|
||||
obs = torch.tensor(
|
||||
[
|
||||
[
|
||||
256.0, # robot arm x position (middle of screen)
|
||||
256.0, # robot arm y position (middle of screen)
|
||||
200.0, # block x position
|
||||
300.0, # block y position
|
||||
np.pi / 2, # block angle (90 degrees)
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
action = policy.predict(obs)
|
||||
|
||||
print("Action shape:", action.shape) # should be [1, 16, 2] - one trajectory of 16 x,y positions
|
||||
print("\nPredicted trajectory:")
|
||||
for i, (x, y) in enumerate(action[0]):
|
||||
print(f"Step {i:2d}: x={x:6.1f}, y={y:6.1f}")
|
||||
@@ -5,8 +5,7 @@
|
||||
|
||||
This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow:
|
||||
|
||||
* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
|
||||
* Even though optional, we load the T5-xxl in NF4 to further reduce the memory foot-print.
|
||||
* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
|
||||
* `train_dreambooth_lora_flux_miniature.py` takes care of training:
|
||||
* Since we already precomputed the text embeddings, we don't load the text encoders.
|
||||
* We load the VAE and use it to precompute the image latents and we then delete it.
|
||||
@@ -164,4 +163,4 @@ image.save("yarn_merged.png")
|
||||
|-------|-------|
|
||||
|  |  |
|
||||
|
||||
As we can notice the first column result follows the style more closely.
|
||||
As we can notice the first column result follows the style more closely.
|
||||
+15
-32
@@ -349,7 +349,7 @@ def parse_args(input_args=None):
|
||||
"--optimizer",
|
||||
type=str,
|
||||
default="AdamW",
|
||||
choices=["AdamW", "Prodigy", "AdEMAMix"],
|
||||
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -357,11 +357,6 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_ademamix",
|
||||
action="store_true",
|
||||
help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
|
||||
@@ -825,18 +820,19 @@ def main(args):
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
|
||||
# Optimizer creation
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
logger.warning(
|
||||
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
|
||||
"Defaulting to adamW"
|
||||
)
|
||||
args.optimizer = "adamw"
|
||||
|
||||
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
|
||||
logger.warning(
|
||||
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
|
||||
f"set to {args.optimizer.lower()}"
|
||||
)
|
||||
|
||||
if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
|
||||
logger.warning(
|
||||
f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
|
||||
f"set to {args.optimizer.lower()}"
|
||||
)
|
||||
|
||||
if args.optimizer.lower() == "adamw":
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
@@ -857,20 +853,6 @@ def main(args):
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
elif args.optimizer.lower() == "ademamix":
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
|
||||
)
|
||||
if args.use_8bit_ademamix:
|
||||
optimizer_class = bnb.optim.AdEMAMix8bit
|
||||
else:
|
||||
optimizer_class = bnb.optim.AdEMAMix
|
||||
|
||||
optimizer = optimizer_class(params_to_optimize)
|
||||
|
||||
if args.optimizer.lower() == "prodigy":
|
||||
try:
|
||||
import prodigyopt
|
||||
@@ -886,6 +868,7 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
@@ -1037,12 +1020,12 @@ def main(args):
|
||||
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
|
||||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
|
||||
|
||||
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
|
||||
model_input.shape[0],
|
||||
model_input.shape[2] // 2,
|
||||
model_input.shape[3] // 2,
|
||||
model_input.shape[2],
|
||||
model_input.shape[3],
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
)
|
||||
@@ -1076,7 +1059,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# handle guidance
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
if transformer.config.guidance_embeds:
|
||||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
@@ -1099,8 +1082,8 @@ def main(args):
|
||||
)[0]
|
||||
model_pred = FluxPipeline._unpack_latents(
|
||||
model_pred,
|
||||
height=model_input.shape[2] * vae_scale_factor,
|
||||
width=model_input.shape[3] * vae_scale_factor,
|
||||
height=int(model_input.shape[2] * vae_scale_factor / 2),
|
||||
width=int(model_input.shape[3] * vae_scale_factor / 2),
|
||||
vae_scale_factor=vae_scale_factor,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
images/
|
||||
output/
|
||||
@@ -1,307 +0,0 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models import PixArtTransformer2DModel
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.utils.torch_utils import is_torch_version
|
||||
|
||||
|
||||
class PixArtControlNetAdapterBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
block_index,
|
||||
# taken from PixArtTransformer2DModel
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 72,
|
||||
dropout: float = 0.0,
|
||||
cross_attention_dim: Optional[int] = 1152,
|
||||
attention_bias: bool = True,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
num_embeds_ada_norm: Optional[int] = 1000,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "ada_norm_single",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
attention_type: Optional[str] = "default",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.block_index = block_index
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# the first block has a zero before layer
|
||||
if self.block_index == 0:
|
||||
self.before_proj = nn.Linear(self.inner_dim, self.inner_dim)
|
||||
nn.init.zeros_(self.before_proj.weight)
|
||||
nn.init.zeros_(self.before_proj.bias)
|
||||
|
||||
self.transformer_block = BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
attention_type=attention_type,
|
||||
)
|
||||
|
||||
self.after_proj = nn.Linear(self.inner_dim, self.inner_dim)
|
||||
nn.init.zeros_(self.after_proj.weight)
|
||||
nn.init.zeros_(self.after_proj.bias)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
self.transformer_block.train(mode)
|
||||
|
||||
if self.block_index == 0:
|
||||
self.before_proj.train(mode)
|
||||
|
||||
self.after_proj.train(mode)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
controlnet_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.block_index == 0:
|
||||
controlnet_states = self.before_proj(controlnet_states)
|
||||
controlnet_states = hidden_states + controlnet_states
|
||||
|
||||
controlnet_states_down = self.transformer_block(
|
||||
hidden_states=controlnet_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
class_labels=None,
|
||||
)
|
||||
|
||||
controlnet_states_left = self.after_proj(controlnet_states_down)
|
||||
|
||||
return controlnet_states_left, controlnet_states_down
|
||||
|
||||
|
||||
class PixArtControlNetAdapterModel(ModelMixin, ConfigMixin):
|
||||
# N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer
|
||||
@register_to_config
|
||||
def __init__(self, num_layers=13) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = num_layers
|
||||
|
||||
self.controlnet_blocks = nn.ModuleList(
|
||||
[PixArtControlNetAdapterBlock(block_index=i) for i in range(num_layers)]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_transformer(cls, transformer: PixArtTransformer2DModel):
|
||||
control_net = PixArtControlNetAdapterModel()
|
||||
|
||||
# copied the specified number of blocks from the transformer
|
||||
for depth in range(control_net.num_layers):
|
||||
control_net.controlnet_blocks[depth].transformer_block.load_state_dict(
|
||||
transformer.transformer_blocks[depth].state_dict()
|
||||
)
|
||||
|
||||
return control_net
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
for block in self.controlnet_blocks:
|
||||
block.train(mode)
|
||||
|
||||
|
||||
class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
transformer: PixArtTransformer2DModel,
|
||||
controlnet: PixArtControlNetAdapterModel,
|
||||
blocks_num=13,
|
||||
init_from_transformer=False,
|
||||
training=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.blocks_num = blocks_num
|
||||
self.gradient_checkpointing = False
|
||||
self.register_to_config(**transformer.config)
|
||||
self.training = training
|
||||
|
||||
if init_from_transformer:
|
||||
# copies the specified number of blocks from the transformer
|
||||
controlnet.from_transformer(transformer, self.blocks_num)
|
||||
|
||||
self.transformer = transformer
|
||||
self.controlnet = controlnet
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
controlnet_cond: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
if self.transformer.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Input
|
||||
batch_size = hidden_states.shape[0]
|
||||
height, width = (
|
||||
hidden_states.shape[-2] // self.transformer.config.patch_size,
|
||||
hidden_states.shape[-1] // self.transformer.config.patch_size,
|
||||
)
|
||||
hidden_states = self.transformer.pos_embed(hidden_states)
|
||||
|
||||
timestep, embedded_timestep = self.transformer.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
if self.transformer.caption_projection is not None:
|
||||
encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
controlnet_states_down = None
|
||||
if controlnet_cond is not None:
|
||||
controlnet_states_down = self.transformer.pos_embed(controlnet_cond)
|
||||
|
||||
# 2. Blocks
|
||||
for block_index, block in enumerate(self.transformer.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
# rc todo: for training and gradient checkpointing
|
||||
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
|
||||
exit(1)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
None,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
# the control nets are only used for the blocks 1 to self.blocks_num
|
||||
if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None:
|
||||
controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[
|
||||
block_index - 1
|
||||
](
|
||||
hidden_states=hidden_states, # used only in the first block
|
||||
controlnet_states=controlnet_states_down,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + controlnet_states_left
|
||||
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=None,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
shift, scale = (
|
||||
self.transformer.scale_shift_table[None]
|
||||
+ embedded_timestep[:, None].to(self.transformer.scale_shift_table.device)
|
||||
).chunk(2, dim=1)
|
||||
hidden_states = self.transformer.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
|
||||
hidden_states = self.transformer.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(
|
||||
-1,
|
||||
height,
|
||||
width,
|
||||
self.transformer.config.patch_size,
|
||||
self.transformer.config.patch_size,
|
||||
self.transformer.out_channels,
|
||||
)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(
|
||||
-1,
|
||||
self.transformer.out_channels,
|
||||
height * self.transformer.config.patch_size,
|
||||
width * self.transformer.config.patch_size,
|
||||
)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +0,0 @@
|
||||
transformers
|
||||
SentencePiece
|
||||
torchvision
|
||||
controlnet-aux
|
||||
datasets
|
||||
# wandb
|
||||
@@ -1,75 +0,0 @@
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from controlnet_aux import HEDdetector
|
||||
|
||||
from diffusers.utils import load_image
|
||||
from examples.research_projects.pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
|
||||
from examples.research_projects.pixart.pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
|
||||
|
||||
|
||||
controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
|
||||
|
||||
weight_dtype = torch.float16
|
||||
image_size = 1024
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
# load controlnet
|
||||
controlnet = PixArtControlNetAdapterModel.from_pretrained(
|
||||
controlnet_repo_id,
|
||||
torch_dtype=weight_dtype,
|
||||
use_safetensors=True,
|
||||
).to(device)
|
||||
|
||||
pipe = PixArtAlphaControlnetPipeline.from_pretrained(
|
||||
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
||||
controlnet=controlnet,
|
||||
torch_dtype=weight_dtype,
|
||||
use_safetensors=True,
|
||||
).to(device)
|
||||
|
||||
images_path = "images"
|
||||
control_image_file = "0_7.jpg"
|
||||
|
||||
# prompt = "cinematic photo of superman in action . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
|
||||
# prompt = "yellow modern car, city in background, beautiful rainy day"
|
||||
# prompt = "modern villa, clear sky, suny day . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
|
||||
# prompt = "robot dog toy in park . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
|
||||
# prompt = "purple car, on highway, beautiful sunny day"
|
||||
# prompt = "realistical photo of a loving couple standing in the open kitchen of the living room, cooking ."
|
||||
prompt = "battleship in space, galaxy in background"
|
||||
|
||||
control_image_name = control_image_file.split(".")[0]
|
||||
|
||||
control_image = load_image(f"{images_path}/{control_image_file}")
|
||||
print(control_image.size)
|
||||
height, width = control_image.size
|
||||
|
||||
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
condition_transform = T.Compose(
|
||||
[
|
||||
T.Lambda(lambda img: img.convert("RGB")),
|
||||
T.CenterCrop([image_size, image_size]),
|
||||
]
|
||||
)
|
||||
|
||||
control_image = condition_transform(control_image)
|
||||
hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
|
||||
|
||||
hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
|
||||
|
||||
# run pipeline
|
||||
with torch.no_grad():
|
||||
out = pipe(
|
||||
prompt=prompt,
|
||||
image=hed_edge,
|
||||
num_inference_steps=14,
|
||||
guidance_scale=4.5,
|
||||
height=image_size,
|
||||
width=image_size,
|
||||
)
|
||||
|
||||
out.images[0].save(f"{images_path}//{control_image_name}_output.jpg")
|
||||
@@ -1,23 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# run
|
||||
# accelerate config
|
||||
|
||||
# check with
|
||||
# accelerate env
|
||||
|
||||
export MODEL_DIR="PixArt-alpha/PixArt-XL-2-512x512"
|
||||
export OUTPUT_DIR="output/pixart-controlnet-hf-diffusers-test"
|
||||
|
||||
accelerate launch ./train_pixart_controlnet_hf.py --mixed_precision="fp16" \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=fusing/fill50k \
|
||||
--resolution=512 \
|
||||
--learning_rate=1e-5 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--report_to="wandb" \
|
||||
--seed=42 \
|
||||
--dataloader_num_workers=8
|
||||
# --lr_scheduler="cosine" --lr_warmup_steps=0 \
|
||||
File diff suppressed because it is too large
Load Diff
@@ -229,11 +229,11 @@ class PromptDiffusionControlNetModel(ControlNetModel):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
|
||||
+1
@@ -1475,6 +1475,7 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -16,8 +16,7 @@ TOKENIZER_MAX_LENGTH = 256
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--vae_encoder_checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--vae_decoder_checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--output_path", required=True, type=str)
|
||||
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
|
||||
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
|
||||
@@ -34,12 +33,6 @@ def swap_scale_shift(weight, dim):
|
||||
return new_weight
|
||||
|
||||
|
||||
def swap_proj_gate(weight):
|
||||
proj, gate = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([gate, proj], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
|
||||
original_state_dict = load_file(ckpt_path, device="cpu")
|
||||
new_state_dict = {}
|
||||
@@ -126,13 +119,11 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
|
||||
)
|
||||
|
||||
# MLP
|
||||
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
|
||||
original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
|
||||
)
|
||||
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
|
||||
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
|
||||
original_state_dict.pop(old_prefix + "mlp_y.w1.weight")
|
||||
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = original_state_dict.pop(
|
||||
old_prefix + "mlp_y.w1.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop(
|
||||
old_prefix + "mlp_y.w2.weight"
|
||||
@@ -153,255 +144,104 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_path):
|
||||
encoder_state_dict = load_file(encoder_ckpt_path, device="cpu")
|
||||
decoder_state_dict = load_file(decoder_ckpt_path, device="cpu")
|
||||
new_state_dict = {}
|
||||
def convert_mochi_decoder_state_dict_to_diffusers(ckpt_path):
|
||||
original_state_dict = load_file(ckpt_path, device="cpu")
|
||||
|
||||
# ==== Decoder =====
|
||||
new_state_dict = {}
|
||||
prefix = "decoder."
|
||||
|
||||
# Convert conv_in
|
||||
new_state_dict[f"{prefix}conv_in.weight"] = decoder_state_dict.pop("blocks.0.0.weight")
|
||||
new_state_dict[f"{prefix}conv_in.bias"] = decoder_state_dict.pop("blocks.0.0.bias")
|
||||
new_state_dict[f"{prefix}conv_in.weight"] = original_state_dict["blocks.0.0.weight"]
|
||||
new_state_dict[f"{prefix}conv_in.bias"] = original_state_dict["blocks.0.0.bias"]
|
||||
|
||||
# Convert block_in (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[-1] = 3
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
|
||||
f"blocks.0.{i+1}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
|
||||
f"blocks.0.{i+1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = original_state_dict[
|
||||
f"blocks.0.{i+1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = original_state_dict[
|
||||
f"blocks.0.{i+1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
|
||||
f"blocks.0.{i+1}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
|
||||
f"blocks.0.{i+1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = original_state_dict[
|
||||
f"blocks.0.{i+1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = original_state_dict[
|
||||
f"blocks.0.{i+1}.stack.5.bias"
|
||||
)
|
||||
]
|
||||
|
||||
# Convert up_blocks (MochiUpBlock3D)
|
||||
down_block_layers = [6, 4, 3] # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4]
|
||||
up_block_layers = [6, 4, 3] # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4]
|
||||
for block in range(3):
|
||||
for i in range(down_block_layers[block]):
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
for i in range(up_block_layers[block]):
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
|
||||
f"blocks.{block+1}.blocks.{i}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
|
||||
f"blocks.{block+1}.blocks.{i}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = original_state_dict[
|
||||
f"blocks.{block+1}.blocks.{i}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = original_state_dict[
|
||||
f"blocks.{block+1}.blocks.{i}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
|
||||
f"blocks.{block+1}.blocks.{i}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
|
||||
f"blocks.{block+1}.blocks.{i}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = original_state_dict[
|
||||
f"blocks.{block+1}.blocks.{i}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = original_state_dict[
|
||||
f"blocks.{block+1}.blocks.{i}.stack.5.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.proj.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
|
||||
]
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = original_state_dict[f"blocks.{block+1}.proj.weight"]
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = original_state_dict[f"blocks.{block+1}.proj.bias"]
|
||||
|
||||
# Convert block_out (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[0] = 3
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
|
||||
f"blocks.4.{i}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
|
||||
f"blocks.4.{i}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = original_state_dict[
|
||||
f"blocks.4.{i}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = original_state_dict[
|
||||
f"blocks.4.{i}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
|
||||
f"blocks.4.{i}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
|
||||
f"blocks.4.{i}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = original_state_dict[
|
||||
f"blocks.4.{i}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
||||
]
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = original_state_dict[
|
||||
f"blocks.4.{i}.stack.5.bias"
|
||||
)
|
||||
]
|
||||
|
||||
# Convert proj_out (Conv1x1 ~= nn.Linear)
|
||||
new_state_dict[f"{prefix}proj_out.weight"] = decoder_state_dict.pop("output_proj.weight")
|
||||
new_state_dict[f"{prefix}proj_out.bias"] = decoder_state_dict.pop("output_proj.bias")
|
||||
|
||||
print("Remaining Decoder Keys:", decoder_state_dict.keys())
|
||||
|
||||
# ==== Encoder =====
|
||||
prefix = "encoder."
|
||||
|
||||
new_state_dict[f"{prefix}proj_in.weight"] = encoder_state_dict.pop("layers.0.weight")
|
||||
new_state_dict[f"{prefix}proj_in.bias"] = encoder_state_dict.pop("layers.0.bias")
|
||||
|
||||
# Convert block_in (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[0] = 3
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert down_blocks (MochiDownBlock3D)
|
||||
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
|
||||
for block in range(3):
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.0.bias"
|
||||
)
|
||||
|
||||
for i in range(down_block_layers[block]):
|
||||
# Convert resnets
|
||||
new_state_dict[
|
||||
f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
|
||||
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[
|
||||
f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
|
||||
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert attentions
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
|
||||
)
|
||||
|
||||
# Convert block_out (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[-1] = 3
|
||||
# Convert resnets
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert attentions
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.attn.out.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.attn.out.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.norm.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.norm.bias"
|
||||
)
|
||||
|
||||
# Convert output layers
|
||||
new_state_dict[f"{prefix}norm_out.norm_layer.weight"] = encoder_state_dict.pop("output_norm.weight")
|
||||
new_state_dict[f"{prefix}norm_out.norm_layer.bias"] = encoder_state_dict.pop("output_norm.bias")
|
||||
new_state_dict[f"{prefix}proj_out.weight"] = encoder_state_dict.pop("output_proj.weight")
|
||||
|
||||
print("Remaining Encoder Keys:", encoder_state_dict.keys())
|
||||
# Convert conv_out (Conv1x1)
|
||||
new_state_dict[f"{prefix}conv_out.weight"] = original_state_dict["output_proj.weight"]
|
||||
new_state_dict[f"{prefix}conv_out.bias"] = original_state_dict["output_proj.bias"]
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@@ -430,11 +270,9 @@ def main(args):
|
||||
if dtype is not None:
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
|
||||
if args.vae_encoder_checkpoint_path is not None and args.vae_decoder_checkpoint_path is not None:
|
||||
if args.vae_checkpoint_path is not None:
|
||||
vae = AutoencoderKLMochi(latent_channels=12, out_channels=3)
|
||||
converted_vae_state_dict = convert_mochi_vae_state_dict_to_diffusers(
|
||||
args.vae_encoder_checkpoint_path, args.vae_decoder_checkpoint_path
|
||||
)
|
||||
converted_vae_state_dict = convert_mochi_decoder_state_dict_to_diffusers(args.vae_checkpoint_path)
|
||||
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
||||
if dtype is not None:
|
||||
vae = vae.to(dtype=dtype)
|
||||
@@ -448,7 +286,7 @@ def main(args):
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
pipe = MochiPipeline(
|
||||
scheduler=FlowMatchEulerDiscreteScheduler(invert_sigmas=True),
|
||||
scheduler=FlowMatchEulerDiscreteScheduler(),
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
@@ -77,11 +77,9 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AllegroTransformer3DModel",
|
||||
"AsymmetricAutoencoderKL",
|
||||
"AuraFlowTransformer2DModel",
|
||||
"AutoencoderKL",
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLMochi",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
@@ -241,7 +239,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipelines"].extend(
|
||||
[
|
||||
"AllegroPipeline",
|
||||
"AltDiffusionImg2ImgPipeline",
|
||||
"AltDiffusionPipeline",
|
||||
"AmusedImg2ImgPipeline",
|
||||
@@ -487,7 +484,7 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
|
||||
else:
|
||||
_import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
|
||||
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
@@ -562,11 +559,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .models import (
|
||||
AllegroTransformer3DModel,
|
||||
AsymmetricAutoencoderKL,
|
||||
AuraFlowTransformer2DModel,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
@@ -707,7 +702,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import (
|
||||
AllegroPipeline,
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
AltDiffusionPipeline,
|
||||
AmusedImg2ImgPipeline,
|
||||
@@ -914,7 +908,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_flax_objects import * # noqa F403
|
||||
else:
|
||||
from .models.controlnets.controlnet_flax import FlaxControlNetModel
|
||||
from .models.controlnet_flax import FlaxControlNetModel
|
||||
from .models.modeling_flax_utils import FlaxModelMixin
|
||||
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.vae_flax import FlaxAutoencoderKL
|
||||
|
||||
@@ -33,14 +33,16 @@ from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -282,9 +284,7 @@ class IPAdapterMixin:
|
||||
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
|
||||
|
||||
for attn_name, attn_processor in unet.attn_processors.items():
|
||||
if isinstance(
|
||||
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
):
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
@@ -342,9 +342,7 @@ class IPAdapterMixin:
|
||||
)
|
||||
attn_procs[name] = (
|
||||
attn_processor_class
|
||||
if isinstance(
|
||||
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
)
|
||||
if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
+118
-124
@@ -51,9 +51,6 @@ if is_accelerate_available():
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
"""
|
||||
@@ -184,119 +181,6 @@ def _remove_text_encoder_monkey_patch(text_encoder):
|
||||
text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
|
||||
def _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weight_name,
|
||||
use_safetensors,
|
||||
local_files_only,
|
||||
cache_dir,
|
||||
force_download,
|
||||
proxies,
|
||||
token,
|
||||
revision,
|
||||
subfolder,
|
||||
user_agent,
|
||||
allow_pickle,
|
||||
):
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
# Here we're relaxing the loading check to enable more Inference API
|
||||
# friendliness where sometimes, it's not at all possible to automatically
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
file_extension=".safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
# try loading non-safetensors weights
|
||||
model_file = None
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||
):
|
||||
if local_files_only or HF_HUB_OFFLINE:
|
||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
return
|
||||
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
||||
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
|
||||
else:
|
||||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
||||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
||||
if len(targeted_files) == 0:
|
||||
return
|
||||
|
||||
# "scheduler" does not correspond to a LoRA checkpoint.
|
||||
# "optimizer" does not correspond to a LoRA checkpoint
|
||||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
||||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
||||
targeted_files = list(
|
||||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
||||
)
|
||||
|
||||
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
||||
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
||||
|
||||
if len(targeted_files) > 1:
|
||||
raise ValueError(
|
||||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
||||
)
|
||||
weight_name = targeted_files[0]
|
||||
return weight_name
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
"""Utility class for handling LoRAs."""
|
||||
|
||||
@@ -350,16 +234,124 @@ class LoraBaseMixin:
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
@classmethod
|
||||
def _fetch_state_dict(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
|
||||
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
|
||||
return _fetch_state_dict(*args, **kwargs)
|
||||
def _fetch_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weight_name,
|
||||
use_safetensors,
|
||||
local_files_only,
|
||||
cache_dir,
|
||||
force_download,
|
||||
proxies,
|
||||
token,
|
||||
revision,
|
||||
subfolder,
|
||||
user_agent,
|
||||
allow_pickle,
|
||||
):
|
||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
||||
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
# Here we're relaxing the loading check to enable more Inference API
|
||||
# friendliness where sometimes, it's not at all possible to automatically
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
file_extension=".safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
# try loading non-safetensors weights
|
||||
model_file = None
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
return state_dict
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
|
||||
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
||||
return _best_guess_weight_name(*args, **kwargs)
|
||||
def _best_guess_weight_name(
|
||||
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||
):
|
||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
||||
|
||||
if local_files_only or HF_HUB_OFFLINE:
|
||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
return
|
||||
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
||||
targeted_files = [
|
||||
f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
|
||||
]
|
||||
else:
|
||||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
||||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
||||
if len(targeted_files) == 0:
|
||||
return
|
||||
|
||||
# "scheduler" does not correspond to a LoRA checkpoint.
|
||||
# "optimizer" does not correspond to a LoRA checkpoint
|
||||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
||||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
||||
targeted_files = list(
|
||||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
||||
)
|
||||
|
||||
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
||||
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
||||
|
||||
if len(targeted_files) > 1:
|
||||
raise ValueError(
|
||||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
||||
)
|
||||
weight_name = targeted_files[0]
|
||||
return weight_name
|
||||
|
||||
def unload_lora_weights(self):
|
||||
"""
|
||||
@@ -733,6 +725,8 @@ class LoraBaseMixin:
|
||||
save_function: Callable,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
@@ -21,6 +21,7 @@ from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
@@ -32,7 +33,7 @@ from ..utils import (
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
)
|
||||
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa
|
||||
from .lora_base import LoraBaseMixin
|
||||
from .lora_conversion_utils import (
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
@@ -61,6 +62,9 @@ TEXT_ENCODER_NAME = "text_encoder"
|
||||
UNET_NAME = "unet"
|
||||
TRANSFORMER_NAME = "transformer"
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
|
||||
|
||||
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
@@ -218,7 +222,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
state_dict = cls._fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
@@ -278,9 +282,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -339,9 +341,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -601,9 +601,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -746,7 +744,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
state_dict = cls._fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
@@ -807,9 +805,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -869,9 +865,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -1188,7 +1182,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
state_dict = cls._fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
@@ -1232,9 +1226,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -1258,17 +1250,13 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
|
||||
if len(transformer_state_dict) > 0:
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name)
|
||||
if not hasattr(self, "transformer")
|
||||
else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
@@ -1313,24 +1301,94 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
|
||||
state_dict = {
|
||||
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
# check with first key if is not in peft format
|
||||
first_key = next(iter(state_dict.keys()))
|
||||
if "lora_A" not in first_key:
|
||||
state_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(transformer)
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_unexpected_keys:
|
||||
warn_msg = (
|
||||
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_missing_keys:
|
||||
warn_msg += (
|
||||
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
||||
f" {', '.join(lora_missing_keys)}."
|
||||
)
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
|
||||
@@ -1366,9 +1424,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -1686,7 +1742,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
state_dict = cls._fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
@@ -1763,9 +1819,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
`Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -1789,18 +1843,14 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
|
||||
if len(transformer_state_dict) > 0:
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
transformer=getattr(self, self.transformer_name)
|
||||
if not hasattr(self, "transformer")
|
||||
else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
@@ -1831,32 +1881,104 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
transformer (`FluxTransformer2DModel`):
|
||||
transformer (`SD3Transformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
|
||||
if transformer_present:
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
|
||||
state_dict = {
|
||||
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
# check with first key if is not in peft format
|
||||
first_key = next(iter(state_dict.keys()))
|
||||
if "lora_A" not in first_key:
|
||||
state_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
prefix = cls.transformer_name
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(transformer)
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_unexpected_keys:
|
||||
warn_msg = (
|
||||
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_missing_keys:
|
||||
warn_msg += (
|
||||
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
||||
f" {', '.join(lora_missing_keys)}."
|
||||
)
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
|
||||
@@ -1892,9 +2014,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -2122,10 +2242,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
text_encoder_name = TEXT_ENCODER_NAME
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
@@ -2138,32 +2255,93 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
transformer (`UVit2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
keys = list(state_dict.keys())
|
||||
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
|
||||
if transformer_present:
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
|
||||
state_dict = {
|
||||
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
|
||||
}
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
|
||||
network_alphas = {
|
||||
k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(transformer)
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_unexpected_keys:
|
||||
warn_msg = (
|
||||
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_missing_keys:
|
||||
warn_msg += (
|
||||
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
||||
f" {', '.join(lora_missing_keys)}."
|
||||
)
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
|
||||
@@ -2199,9 +2377,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -2443,7 +2619,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
state_dict = cls._fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
@@ -2482,9 +2658,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -2517,7 +2691,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
@@ -2529,29 +2703,99 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`CogVideoXTransformer3DModel`):
|
||||
transformer (`SD3Transformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
|
||||
state_dict = {
|
||||
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
# check with first key if is not in peft format
|
||||
first_key = next(iter(state_dict.keys()))
|
||||
if "lora_A" not in first_key:
|
||||
state_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(transformer)
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_unexpected_keys:
|
||||
warn_msg = (
|
||||
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_missing_keys:
|
||||
warn_msg += (
|
||||
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
||||
f" {', '.join(lora_missing_keys)}."
|
||||
)
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
@classmethod
|
||||
# Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
|
||||
|
||||
@@ -16,32 +16,18 @@ import inspect
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import (
|
||||
MIN_PEFT_VERSION,
|
||||
USE_PEFT_BACKEND,
|
||||
check_peft_version,
|
||||
convert_unet_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
logging,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .lora_base import _fetch_state_dict
|
||||
from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"UNet2DConditionModel": _maybe_expand_lora_scales,
|
||||
"UNetMotionModel": _maybe_expand_lora_scales,
|
||||
@@ -67,215 +53,6 @@ class PeftAdapterMixin:
|
||||
|
||||
_hf_peft_config_loaded = False
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
|
||||
r"""
|
||||
Loads a LoRA adapter into the underlying model.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
prefix (`str`, *optional*): Prefix to filter the state dict.
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
adapter_name = kwargs.pop("adapter_name", None)
|
||||
network_alphas = kwargs.pop("network_alphas", None)
|
||||
_pipeline = kwargs.pop("_pipeline", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
|
||||
allow_pickle = False
|
||||
|
||||
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
transformer_keys = [k for k in keys if k.startswith(prefix)]
|
||||
if len(transformer_keys) > 0:
|
||||
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
# check with first key if is not in peft format
|
||||
first_key = next(iter(state_dict.keys()))
|
||||
if "lora_A" not in first_key:
|
||||
state_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||
|
||||
if adapter_name in getattr(self, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(self)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
|
||||
# Now we remove any existing hooks to `_pipeline`.
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
|
||||
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_unexpected_keys:
|
||||
warn_msg = (
|
||||
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_missing_keys:
|
||||
warn_msg += (
|
||||
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
||||
f" {', '.join(lora_missing_keys)}."
|
||||
)
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
|
||||
@@ -765,7 +765,6 @@ class UNet2DConditionLoadersMixin:
|
||||
from ..models.attention_processor import (
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
@@ -805,15 +804,11 @@ class UNet2DConditionLoadersMixin:
|
||||
if cross_attention_dim is None or "motion_modules" in name:
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_procs[name] = attn_processor_class()
|
||||
|
||||
else:
|
||||
if "XFormers" in str(self.attn_processors[name].__class__):
|
||||
attn_processor_class = IPAdapterXFormersAttnProcessor
|
||||
else:
|
||||
attn_processor_class = (
|
||||
IPAdapterAttnProcessor2_0
|
||||
if hasattr(F, "scaled_dot_product_attention")
|
||||
else IPAdapterAttnProcessor
|
||||
)
|
||||
attn_processor_class = (
|
||||
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
|
||||
)
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
|
||||
@@ -28,7 +28,6 @@ if is_torch_available():
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
@@ -36,16 +35,12 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_hunyuan"] = [
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
]
|
||||
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
|
||||
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
|
||||
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
||||
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
|
||||
@@ -60,7 +55,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
@@ -78,7 +72,7 @@ if is_torch_available():
|
||||
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
|
||||
@@ -89,7 +83,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .autoencoders import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
@@ -98,23 +91,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ConsistencyDecoderVAE,
|
||||
VQModel,
|
||||
)
|
||||
from .controlnets import (
|
||||
ControlNetModel,
|
||||
ControlNetXSAdapter,
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
MultiControlNetModel,
|
||||
SD3ControlNetModel,
|
||||
SD3MultiControlNetModel,
|
||||
SparseControlNetModel,
|
||||
UNetControlNetXSModel,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
|
||||
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from .controlnet_sparsectrl import SparseControlNetModel
|
||||
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformers import (
|
||||
AllegroTransformer3DModel,
|
||||
AuraFlowTransformer2DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
@@ -148,7 +133,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from .controlnets import FlaxControlNetModel
|
||||
from .controlnet_flax import FlaxControlNetModel
|
||||
from .unets import FlaxUNet2DConditionModel
|
||||
from .vae_flax import FlaxAutoencoderKL
|
||||
|
||||
|
||||
@@ -134,8 +134,9 @@ class SwiGLU(nn.Module):
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True, flip_gate: bool = False):
|
||||
super().__init__()
|
||||
self.flip_gate = flip_gate
|
||||
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
|
||||
self.activation = nn.SiLU()
|
||||
@@ -143,6 +144,8 @@ class SwiGLU(nn.Module):
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states, gate = hidden_states.chunk(2, dim=-1)
|
||||
if self.flip_gate:
|
||||
hidden_states, gate = gate, hidden_states
|
||||
return hidden_states * self.activation(gate)
|
||||
|
||||
|
||||
|
||||
@@ -1206,6 +1206,7 @@ class FeedForward(nn.Module):
|
||||
final_dropout: bool = False,
|
||||
inner_dim=None,
|
||||
bias: bool = True,
|
||||
flip_gate: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
@@ -1221,7 +1222,7 @@ class FeedForward(nn.Module):
|
||||
elif activation_fn == "geglu-approximate":
|
||||
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "swiglu":
|
||||
act_fn = SwiGLU(dim, inner_dim, bias=bias)
|
||||
act_fn = SwiGLU(dim, inner_dim, bias=bias, flip_gate=flip_gate)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
|
||||
@@ -124,12 +124,11 @@ class Attention(nn.Module):
|
||||
context_pre_only=None,
|
||||
pre_only=False,
|
||||
elementwise_affine: bool = True,
|
||||
is_causal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# To prevent circular import.
|
||||
from .normalization import FP32LayerNorm, LpNorm, RMSNorm
|
||||
from .normalization import FP32LayerNorm, RMSNorm
|
||||
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||
@@ -147,7 +146,6 @@ class Attention(nn.Module):
|
||||
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
self.pre_only = pre_only
|
||||
self.is_causal = is_causal
|
||||
|
||||
# we make use of this private variable to know whether this class is loaded
|
||||
# with an deprecated state dict so that we can convert it on the fly
|
||||
@@ -196,9 +194,6 @@ class Attention(nn.Module):
|
||||
elif qk_norm == "rms_norm":
|
||||
self.norm_q = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_k = RMSNorm(dim_head, eps=eps)
|
||||
elif qk_norm == "l2":
|
||||
self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
|
||||
self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
|
||||
else:
|
||||
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
|
||||
|
||||
@@ -318,10 +313,7 @@ class Attention(nn.Module):
|
||||
XFormersAttnAddedKVProcessor,
|
||||
),
|
||||
)
|
||||
is_ip_adapter = hasattr(self, "processor") and isinstance(
|
||||
self.processor,
|
||||
(IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
|
||||
)
|
||||
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if is_added_kv_processor and is_custom_diffusion:
|
||||
raise NotImplementedError(
|
||||
@@ -371,19 +363,6 @@ class Attention(nn.Module):
|
||||
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
||||
)
|
||||
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
||||
elif is_ip_adapter:
|
||||
processor = IPAdapterXFormersAttnProcessor(
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
scale=self.processor.scale,
|
||||
num_tokens=self.processor.num_tokens,
|
||||
attention_op=attention_op,
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_ip"):
|
||||
processor.to(
|
||||
device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
|
||||
)
|
||||
else:
|
||||
processor = XFormersAttnProcessor(attention_op=attention_op)
|
||||
else:
|
||||
@@ -1544,100 +1523,6 @@ class FusedJointAttnProcessor2_0:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class AllegroAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None and not attn.is_cross_attention:
|
||||
from .embeddings import apply_rotary_emb_allegro
|
||||
|
||||
query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
|
||||
key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AuraFlowAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing Aura Flow."""
|
||||
|
||||
@@ -1917,112 +1802,6 @@ class FluxAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxAttnProcessor2_0_NPU:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
||||
if encoder_hidden_states is not None:
|
||||
# `context` projections.
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
|
||||
# attention
|
||||
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
if query.dtype in (torch.float16, torch.bfloat16):
|
||||
hidden_states = torch_npu.npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn.heads,
|
||||
input_layout="BNSD",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]),
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedFluxAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
@@ -2117,117 +1896,6 @@ class FusedFluxAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedFluxAttnProcessor2_0_NPU:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
|
||||
# `sample` projections.
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
||||
# `context` projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
||||
split_size = encoder_qkv.shape[-1] // 3
|
||||
(
|
||||
encoder_hidden_states_query_proj,
|
||||
encoder_hidden_states_key_proj,
|
||||
encoder_hidden_states_value_proj,
|
||||
) = torch.split(encoder_qkv, split_size, dim=-1)
|
||||
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
|
||||
# attention
|
||||
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
if query.dtype in (torch.float16, torch.bfloat16):
|
||||
hidden_states = torch_npu.npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn.heads,
|
||||
input_layout="BNSD",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]),
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
@@ -2738,91 +2406,6 @@ class AttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MochiVaeAttnProcessor2_0:
|
||||
r"""
|
||||
Attention processor used in Mochi VAE.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
is_single_frame = hidden_states.shape[1] == 1
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if is_single_frame:
|
||||
hidden_states = attn.to_v(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
return hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class StableAudioAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
@@ -4558,238 +4141,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAdapterXFormersAttnProcessor(torch.nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapter using xFormers.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float` or `List[float]`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
attention_op (`Callable`, *optional*, defaults to `None`):
|
||||
The base
|
||||
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
||||
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
||||
operator.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
cross_attention_dim=None,
|
||||
num_tokens=(4,),
|
||||
scale=1.0,
|
||||
attention_op: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.attention_op = attention_op
|
||||
|
||||
if not isinstance(num_tokens, (tuple, list)):
|
||||
num_tokens = [num_tokens]
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * len(num_tokens)
|
||||
if len(scale) != len(num_tokens):
|
||||
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.ModuleList(
|
||||
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
||||
)
|
||||
self.to_v_ip = nn.ModuleList(
|
||||
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
ip_adapter_masks: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# expand our mask's singleton query_tokens dimension:
|
||||
# [batch*heads, 1, key_tokens] ->
|
||||
# [batch*heads, query_tokens, key_tokens]
|
||||
# so that it can be added as a bias onto the attention scores that xformers computes:
|
||||
# [batch*heads, query_tokens, key_tokens]
|
||||
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
||||
_, query_tokens, _ = hidden_states.shape
|
||||
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query).contiguous()
|
||||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
value = attn.head_to_batch_dim(value).contiguous()
|
||||
|
||||
hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, key, value, attn_bias=attention_mask, op=self.attention_op
|
||||
)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if ip_hidden_states:
|
||||
if ip_adapter_masks is not None:
|
||||
if not isinstance(ip_adapter_masks, List):
|
||||
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
||||
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
||||
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
||||
raise ValueError(
|
||||
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
|
||||
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
|
||||
f"({len(ip_hidden_states)})"
|
||||
)
|
||||
else:
|
||||
for index, (mask, scale, ip_state) in enumerate(
|
||||
zip(ip_adapter_masks, self.scale, ip_hidden_states)
|
||||
):
|
||||
if mask is None:
|
||||
continue
|
||||
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
|
||||
raise ValueError(
|
||||
"Each element of the ip_adapter_masks array should be a tensor with shape "
|
||||
"[1, num_images_for_ip_adapter, height, width]."
|
||||
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
|
||||
)
|
||||
if mask.shape[1] != ip_state.shape[1]:
|
||||
raise ValueError(
|
||||
f"Number of masks ({mask.shape[1]}) does not match "
|
||||
f"number of ip images ({ip_state.shape[1]}) at index {index}"
|
||||
)
|
||||
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
|
||||
raise ValueError(
|
||||
f"Number of masks ({mask.shape[1]}) does not match "
|
||||
f"number of scales ({len(scale)}) at index {index}"
|
||||
)
|
||||
else:
|
||||
ip_adapter_masks = [None] * len(self.scale)
|
||||
|
||||
# for ip-adapter
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
|
||||
):
|
||||
skip = False
|
||||
if isinstance(scale, list):
|
||||
if all(s == 0 for s in scale):
|
||||
skip = True
|
||||
elif scale == 0:
|
||||
skip = True
|
||||
if not skip:
|
||||
if mask is not None:
|
||||
mask = mask.to(torch.float16)
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * mask.shape[1]
|
||||
|
||||
current_num_images = mask.shape[1]
|
||||
for i in range(current_num_images):
|
||||
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
|
||||
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
||||
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
||||
|
||||
_current_ip_hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, ip_key, ip_value, op=self.attention_op
|
||||
)
|
||||
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
|
||||
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
|
||||
|
||||
mask_downsample = IPAdapterMaskProcessor.downsample(
|
||||
mask[:, i, :, :],
|
||||
batch_size,
|
||||
_current_ip_hidden_states.shape[1],
|
||||
_current_ip_hidden_states.shape[2],
|
||||
)
|
||||
|
||||
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
|
||||
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
|
||||
else:
|
||||
ip_key = to_k_ip(current_ip_hidden_states)
|
||||
ip_value = to_v_ip(current_ip_hidden_states)
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
||||
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
||||
|
||||
current_ip_hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, ip_key, ip_value, op=self.attention_op
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
|
||||
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + scale * current_ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PAGIdentitySelfAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -94,13 +94,11 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
|
||||
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
||||
|
||||
# TODO(aryan): configure calculation based on stride and dilation in the future.
|
||||
# Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
|
||||
time_pad = time_kernel_size - 1
|
||||
height_pad = (height_kernel_size - 1) // 2
|
||||
width_pad = (width_kernel_size - 1) // 2
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
|
||||
height_pad = height_kernel_size // 2
|
||||
width_pad = width_kernel_size // 2
|
||||
|
||||
self.height_pad = height_pad
|
||||
self.width_pad = width_pad
|
||||
self.time_pad = time_pad
|
||||
@@ -109,7 +107,7 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
self.temporal_dim = 2
|
||||
self.time_kernel_size = time_kernel_size
|
||||
|
||||
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
|
||||
stride = (stride, 1, 1)
|
||||
dilation = (dilation, 1, 1)
|
||||
self.conv = CogVideoXSafeConv3d(
|
||||
in_channels=in_channels,
|
||||
@@ -122,24 +120,18 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
def fake_context_parallel_forward(
|
||||
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
if self.pad_mode == "replicate":
|
||||
inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
|
||||
else:
|
||||
kernel_size = self.time_kernel_size
|
||||
if kernel_size > 1:
|
||||
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
||||
inputs = torch.cat(cached_inputs + [inputs], dim=2)
|
||||
kernel_size = self.time_kernel_size
|
||||
if kernel_size > 1:
|
||||
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
||||
inputs = torch.cat(cached_inputs + [inputs], dim=2)
|
||||
return inputs
|
||||
|
||||
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
|
||||
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||
|
||||
if self.pad_mode == "replicate":
|
||||
conv_cache = None
|
||||
else:
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
||||
|
||||
output = self.conv(inputs)
|
||||
return output, conv_cache
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,32 +11,860 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..utils import deprecate
|
||||
from .controlnets.controlnet import ( # noqa
|
||||
BaseOutput,
|
||||
ControlNetConditioningEmbedding,
|
||||
ControlNetModel,
|
||||
ControlNetOutput,
|
||||
zero_module,
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders.single_file_model import FromOriginalModelMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
get_down_block,
|
||||
)
|
||||
from .unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
class ControlNetOutput(ControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
|
||||
deprecate("ControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ControlNetModel(ControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
|
||||
deprecate("ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
@dataclass
|
||||
class ControlNetOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`ControlNetModel`].
|
||||
|
||||
Args:
|
||||
down_block_res_samples (`tuple[torch.Tensor]`):
|
||||
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
||||
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||
used to condition the original UNet's downsampling activations.
|
||||
mid_down_block_re_sample (`torch.Tensor`):
|
||||
The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
|
||||
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
|
||||
deprecate("ControlNetConditioningEmbedding", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
class ControlNetConditioningEmbedding(nn.Module):
|
||||
"""
|
||||
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
||||
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
||||
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
||||
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
||||
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
||||
model) to encode image-space conditions ... into feature maps ..."
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
conditioning_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = nn.ModuleList([])
|
||||
|
||||
for i in range(len(block_out_channels) - 1):
|
||||
channel_in = block_out_channels[i]
|
||||
channel_out = block_out_channels[i + 1]
|
||||
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
||||
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
||||
|
||||
self.conv_out = zero_module(
|
||||
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
|
||||
def forward(self, conditioning):
|
||||
embedding = self.conv_in(conditioning)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
for block in self.blocks:
|
||||
embedding = block(embedding)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
embedding = self.conv_out(embedding)
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter.
|
||||
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||
)
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||
encoder_hid_dim_type = "text_proj"
|
||||
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||
|
||||
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type == "text_proj":
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||
elif encoder_hid_dim_type == "text_image_proj":
|
||||
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
|
||||
self.encoder_hid_proj = TextImageProjection(
|
||||
text_embed_dim=encoder_hid_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
elif encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||
)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
if projection_class_embeddings_input_dim is None:
|
||||
raise ValueError(
|
||||
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||
)
|
||||
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||
# 2. it projects from an arbitrary input dimension.
|
||||
#
|
||||
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == "text":
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
)
|
||||
elif addition_embed_type == "text_image":
|
||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
|
||||
# control net conditioning embedding
|
||||
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
block_out_channels=conditioning_embedding_out_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.controlnet_down_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
# mid
|
||||
mid_block_channel = block_out_channels[-1]
|
||||
|
||||
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_mid_block = controlnet_block
|
||||
|
||||
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
in_channels=mid_block_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
elif mid_block_type == "UNetMidBlock2D":
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
num_layers=0,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
add_attention=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
||||
|
||||
@classmethod
|
||||
def from_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 3,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
||||
|
||||
Parameters:
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
||||
where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||
)
|
||||
|
||||
controlnet = cls(
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
in_channels=unet.config.in_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
down_block_types=unet.config.down_block_types,
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
downsample_padding=unet.config.downsample_padding,
|
||||
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||
act_fn=unet.config.act_fn,
|
||||
norm_num_groups=unet.config.norm_num_groups,
|
||||
norm_eps=unet.config.norm_eps,
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
attention_head_dim=unet.config.attention_head_dim,
|
||||
num_attention_heads=unet.config.num_attention_heads,
|
||||
use_linear_projection=unet.config.use_linear_projection,
|
||||
class_embed_type=unet.config.class_embed_type,
|
||||
num_class_embeds=unet.config.num_class_embeds,
|
||||
upcast_attention=unet.config.upcast_attention,
|
||||
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||
mid_block_type=unet.config.mid_block_type,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
if load_weights_from_unet:
|
||||
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
||||
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||
|
||||
if controlnet.class_embedding:
|
||||
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||
|
||||
if hasattr(controlnet, "add_embedding"):
|
||||
controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
|
||||
|
||||
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
||||
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
||||
|
||||
return controlnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_sliceable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_sliceable_dims(module)
|
||||
|
||||
num_sliceable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
"""
|
||||
The [`ControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
elif channel_order == "bgr":
|
||||
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. Control net blocks
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return ControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
|
||||
+5
-5
@@ -19,11 +19,11 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
|
||||
from ...configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from ..modeling_flax_utils import FlaxModelMixin
|
||||
from ..unets.unet_2d_blocks_flax import (
|
||||
from ..configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from .modeling_flax_utils import FlaxModelMixin
|
||||
from .unets.unet_2d_blocks_flax import (
|
||||
FlaxCrossAttnDownBlock2D,
|
||||
FlaxDownBlock2D,
|
||||
FlaxUNetMidBlock2DCrossAttn,
|
||||
@@ -12,30 +12,525 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import PeftAdapterMixin
|
||||
from ..models.attention_processor import AttentionProcessor
|
||||
from ..models.modeling_utils import ModelMixin
|
||||
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
|
||||
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from .modeling_outputs import Transformer2DModelOutput
|
||||
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxControlNetOutput(FluxControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
|
||||
deprecate("FluxControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
@dataclass
|
||||
class FluxControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_single_block_samples: Tuple[torch.Tensor]
|
||||
|
||||
|
||||
class FluxControlNetModel(FluxControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
|
||||
deprecate("FluxControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
||||
text_time_guidance_cls = (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
||||
)
|
||||
self.time_text_embed = text_time_guidance_cls(
|
||||
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||
)
|
||||
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.single_transformer_blocks)):
|
||||
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.union = num_mode is not None
|
||||
if self.union:
|
||||
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
|
||||
|
||||
if conditioning_embedding_channels is not None:
|
||||
self.input_hint_block = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
|
||||
)
|
||||
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
||||
else:
|
||||
self.input_hint_block = None
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self):
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls,
|
||||
transformer,
|
||||
num_layers: int = 4,
|
||||
num_single_layers: int = 10,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
load_weights_from_transformer=True,
|
||||
):
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers
|
||||
config["num_single_layers"] = num_single_layers
|
||||
config["attention_head_dim"] = attention_head_dim
|
||||
config["num_attention_heads"] = num_attention_heads
|
||||
|
||||
controlnet = cls(**config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
|
||||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
|
||||
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
|
||||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
|
||||
controlnet.single_transformer_blocks.load_state_dict(
|
||||
transformer.single_transformer_blocks.state_dict(), strict=False
|
||||
)
|
||||
|
||||
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
|
||||
|
||||
return controlnet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
controlnet_mode: torch.Tensor = None,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
controlnet_mode (`torch.Tensor`):
|
||||
The mode tensor of shape `(batch_size, 1)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
if self.input_hint_block is not None:
|
||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
|
||||
height = height_pw // self.config.patch_size
|
||||
width = width_pw // self.config.patch_size
|
||||
controlnet_cond = controlnet_cond.reshape(
|
||||
batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
|
||||
)
|
||||
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
|
||||
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
|
||||
# add
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
else:
|
||||
guidance = None
|
||||
temb = (
|
||||
self.time_text_embed(timestep, pooled_projections)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if self.union:
|
||||
# union mode
|
||||
if controlnet_mode is None:
|
||||
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
||||
# union mode emb
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
block_samples = ()
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
single_block_samples = ()
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
|
||||
|
||||
# controlnet block
|
||||
controlnet_block_samples = ()
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
|
||||
block_sample = controlnet_block(block_sample)
|
||||
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
||||
|
||||
controlnet_single_block_samples = ()
|
||||
for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
|
||||
single_block_sample = controlnet_block(single_block_sample)
|
||||
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
|
||||
|
||||
# scaling
|
||||
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
|
||||
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
|
||||
|
||||
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
|
||||
controlnet_single_block_samples = (
|
||||
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_samples, controlnet_single_block_samples)
|
||||
|
||||
return FluxControlNetOutput(
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
)
|
||||
|
||||
|
||||
class FluxMultiControlNetModel(FluxMultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
|
||||
deprecate("FluxMultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
class FluxMultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
`FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
|
||||
|
||||
This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
|
||||
compatible with `FluxControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[FluxControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`FluxControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
controlnet_mode: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FluxControlNetOutput, Tuple]:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1 and self.nets[0].union:
|
||||
controlnet = self.nets[0]
|
||||
|
||||
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
|
||||
# Regular Multi-ControlNets
|
||||
# load all ControlNets into memories
|
||||
else:
|
||||
for i, (image, mode, scale, controlnet) in enumerate(
|
||||
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
|
||||
):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
if block_samples is not None and control_block_samples is not None:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
if single_block_samples is not None and control_single_block_samples is not None:
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
|
||||
return control_block_samples, control_single_block_samples
|
||||
|
||||
+7
-7
@@ -17,17 +17,17 @@ from typing import Dict, Optional, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..embeddings import (
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .attention_processor import AttentionProcessor
|
||||
from .controlnet import BaseOutput, Tuple, zero_module
|
||||
from .embeddings import (
|
||||
HunyuanCombinedTimestepTextSizeStyleEmbedding,
|
||||
PatchEmbed,
|
||||
PixArtAlphaTextProjection,
|
||||
)
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock
|
||||
from .controlnet import BaseOutput, Tuple, zero_module
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -13,29 +13,410 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ..models.attention import JointTransformerBlock
|
||||
from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
|
||||
from ..models.modeling_outputs import Transformer2DModelOutput
|
||||
from ..models.modeling_utils import ModelMixin
|
||||
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .controlnet import BaseOutput, zero_module
|
||||
from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SD3ControlNetOutput(SD3ControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
|
||||
deprecate("SD3ControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
@dataclass
|
||||
class SD3ControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
|
||||
|
||||
class SD3ControlNetModel(SD3ControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
|
||||
deprecate("SD3ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 128,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
num_layers: int = 18,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 18,
|
||||
joint_attention_dim: int = 4096,
|
||||
caption_projection_dim: int = 1152,
|
||||
pooled_projection_dim: int = 2048,
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
default_out_channels = in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
)
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||
)
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
|
||||
|
||||
# `attention_head_dim` is doubled to account for the mixing.
|
||||
# It needs to crafted when we get the actual checkpoints.
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
context_pre_only=False,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
pos_embed_input = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels + extra_conditioning_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_type=None,
|
||||
)
|
||||
self.pos_embed_input = zero_module(pos_embed_input)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
||||
|
||||
Parameters:
|
||||
chunk_size (`int`, *optional*):
|
||||
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
||||
over each tensor of dim=`dim`.
|
||||
dim (`int`, *optional*, defaults to `0`):
|
||||
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
||||
or dim=1 (sequence length).
|
||||
"""
|
||||
if dim not in [0, 1]:
|
||||
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
||||
|
||||
# By default chunk size is 1
|
||||
chunk_size = chunk_size or 1
|
||||
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedJointAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
|
||||
):
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers or config.num_layers
|
||||
config["extra_conditioning_channels"] = num_extra_conditioning_channels
|
||||
controlnet = cls(**config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
|
||||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
|
||||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
|
||||
|
||||
controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
|
||||
|
||||
return controlnet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
pooled_projections: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`SD3Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
# add
|
||||
hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
|
||||
|
||||
block_res_samples = ()
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
)
|
||||
|
||||
block_res_samples = block_res_samples + (hidden_states,)
|
||||
|
||||
controlnet_block_res_samples = ()
|
||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
||||
block_res_sample = controlnet_block(block_res_sample)
|
||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
||||
|
||||
# 6. scaling
|
||||
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_res_samples,)
|
||||
|
||||
return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
|
||||
|
||||
|
||||
class SD3MultiControlNetModel(SD3MultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
|
||||
deprecate("SD3MultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
class SD3MultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
`SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
|
||||
|
||||
This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
|
||||
compatible with `SD3ControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[SD3ControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`SD3ControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
pooled_projections: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SD3ControlNetOutput, Tuple]:
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
pooled_projections=pooled_projections,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=scale,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
|
||||
]
|
||||
control_block_samples = (tuple(control_block_samples),)
|
||||
|
||||
return control_block_samples
|
||||
|
||||
@@ -12,35 +12,777 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sparsectrl import ( # noqa
|
||||
SparseControlNetConditioningEmbedding,
|
||||
SparseControlNetModel,
|
||||
SparseControlNetOutput,
|
||||
zero_module,
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalModelMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
|
||||
from .unets.unet_2d_condition import UNet2DConditionModel
|
||||
from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SparseControlNetOutput(SparseControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
|
||||
deprecate("SparseControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
@dataclass
|
||||
class SparseControlNetOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`SparseControlNetModel`].
|
||||
|
||||
Args:
|
||||
down_block_res_samples (`tuple[torch.Tensor]`):
|
||||
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
||||
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||
used to condition the original UNet's downsampling activations.
|
||||
mid_down_block_re_sample (`torch.Tensor`):
|
||||
The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
|
||||
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
|
||||
deprecate("SparseControlNetConditioningEmbedding", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
class SparseControlNetConditioningEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
conditioning_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
self.blocks = nn.ModuleList([])
|
||||
|
||||
for i in range(len(block_out_channels) - 1):
|
||||
channel_in = block_out_channels[i]
|
||||
channel_out = block_out_channels[i + 1]
|
||||
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
||||
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
||||
|
||||
self.conv_out = zero_module(
|
||||
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
|
||||
def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
|
||||
embedding = self.conv_in(conditioning)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
for block in self.blocks:
|
||||
embedding = block(embedding)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
embedding = self.conv_out(embedding)
|
||||
return embedding
|
||||
|
||||
|
||||
class SparseControlNetModel(SparseControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
|
||||
deprecate("SparseControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
|
||||
Models](https://arxiv.org/abs/2311.16933).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
conditioning_channels (`int`, defaults to 4):
|
||||
The number of input channels in the controlnet conditional embedding module. If
|
||||
`concat_condition_embedding` is True, the value provided here is incremented by 1.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer layers to use in each layer in the middle block.
|
||||
attention_head_dim (`int` or `Tuple[int]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
num_attention_heads (`int` or `Tuple[int]`, *optional*):
|
||||
The number of heads to use for multi-head attention.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `rgb`):
|
||||
motion_max_seq_length (`int`, defaults to `32`):
|
||||
The maximum sequence length to use in the motion module.
|
||||
motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`):
|
||||
The number of heads to use in each attention layer of the motion module.
|
||||
concat_conditioning_mask (`bool`, defaults to `True`):
|
||||
use_simplified_condition_embedding (`bool`, defaults to `True`):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
motion_max_seq_length: int = 32,
|
||||
motion_num_attention_heads: int = 8,
|
||||
concat_conditioning_mask: bool = True,
|
||||
use_simplified_condition_embedding: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_simplified_condition_embedding = use_simplified_condition_embedding
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
if isinstance(temporal_transformer_layers_per_block, int):
|
||||
temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||
)
|
||||
|
||||
if concat_conditioning_mask:
|
||||
conditioning_channels = conditioning_channels + 1
|
||||
|
||||
self.concat_conditioning_mask = concat_conditioning_mask
|
||||
|
||||
# control net conditioning embedding
|
||||
if use_simplified_condition_embedding:
|
||||
self.controlnet_cond_embedding = zero_module(
|
||||
nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
)
|
||||
else:
|
||||
self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
block_out_channels=conditioning_embedding_out_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.controlnet_down_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(cross_attention_dim, int):
|
||||
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
if isinstance(motion_num_attention_heads, int):
|
||||
motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
if down_block_type == "CrossAttnDownBlockMotion":
|
||||
down_block = CrossAttnDownBlockMotion(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=0,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_pre_norm=True,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
cross_attention_dim=cross_attention_dim[i],
|
||||
add_downsample=not is_final_block,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
temporal_num_attention_heads=motion_num_attention_heads[i],
|
||||
temporal_max_seq_length=motion_max_seq_length,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
|
||||
temporal_double_self_attention=False,
|
||||
)
|
||||
elif down_block_type == "DownBlockMotion":
|
||||
down_block = DownBlockMotion(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=0,
|
||||
num_layers=layers_per_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_pre_norm=True,
|
||||
add_downsample=not is_final_block,
|
||||
temporal_num_attention_heads=motion_num_attention_heads[i],
|
||||
temporal_max_seq_length=motion_max_seq_length,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
|
||||
temporal_double_self_attention=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid `block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`"
|
||||
)
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
# mid
|
||||
mid_block_channels = block_out_channels[-1]
|
||||
|
||||
controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_mid_block = controlnet_block
|
||||
|
||||
if transformer_layers_per_mid_block is None:
|
||||
transformer_layers_per_mid_block = (
|
||||
transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
|
||||
)
|
||||
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=mid_block_channels,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=0,
|
||||
num_layers=1,
|
||||
transformer_layers_per_block=transformer_layers_per_mid_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_pre_norm=True,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim[-1],
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
attention_type="default",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 3,
|
||||
) -> "SparseControlNetModel":
|
||||
r"""
|
||||
Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`].
|
||||
|
||||
Parameters:
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also
|
||||
copied where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
)
|
||||
down_block_types = unet.config.down_block_types
|
||||
|
||||
for i in range(len(down_block_types)):
|
||||
if "CrossAttn" in down_block_types[i]:
|
||||
down_block_types[i] = "CrossAttnDownBlockMotion"
|
||||
elif "Down" in down_block_types[i]:
|
||||
down_block_types[i] = "DownBlockMotion"
|
||||
else:
|
||||
raise ValueError("Invalid `block_type` encountered. Must be a cross-attention or down block")
|
||||
|
||||
controlnet = cls(
|
||||
in_channels=unet.config.in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
down_block_types=unet.config.down_block_types,
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
downsample_padding=unet.config.downsample_padding,
|
||||
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||
act_fn=unet.config.act_fn,
|
||||
norm_num_groups=unet.config.norm_num_groups,
|
||||
norm_eps=unet.config.norm_eps,
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
attention_head_dim=unet.config.attention_head_dim,
|
||||
num_attention_heads=unet.config.num_attention_heads,
|
||||
use_linear_projection=unet.config.use_linear_projection,
|
||||
upcast_attention=unet.config.upcast_attention,
|
||||
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
)
|
||||
|
||||
if load_weights_from_unet:
|
||||
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False)
|
||||
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False)
|
||||
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False)
|
||||
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
||||
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
||||
|
||||
return controlnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_sliceable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_sliceable_dims(module)
|
||||
|
||||
num_sliceable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
conditioning_mask: Optional[torch.Tensor] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
"""
|
||||
The [`SparseControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
Returns:
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape
|
||||
sample = torch.zeros_like(sample)
|
||||
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
elif channel_order == "bgr":
|
||||
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
emb = emb.repeat_interleave(sample_num_frames, dim=0)
|
||||
|
||||
# 2. pre-process
|
||||
batch_size, channels, num_frames, height, width = sample.shape
|
||||
|
||||
sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
batch_frames, channels, height, width = sample.shape
|
||||
sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width)
|
||||
|
||||
if self.concat_conditioning_mask:
|
||||
controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
|
||||
|
||||
batch_size, channels, num_frames, height, width = controlnet_cond.shape
|
||||
controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape(
|
||||
batch_size * num_frames, channels, height, width
|
||||
)
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
batch_frames, channels, height, width = controlnet_cond.shape
|
||||
controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width)
|
||||
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
batch_size, num_frames, channels, height, width = sample.shape
|
||||
sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
num_frames=num_frames,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. Control net blocks
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return SparseControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
||||
|
||||
|
||||
# Copied from diffusers.models.controlnet.zero_module
|
||||
def zero_module(module: nn.Module) -> nn.Module:
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
|
||||
+10
-11
@@ -19,10 +19,10 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, is_torch_version, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention_processor import (
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, is_torch_version, logging
|
||||
from ..utils.torch_utils import apply_freeu
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
Attention,
|
||||
@@ -31,9 +31,10 @@ from ..attention_processor import (
|
||||
AttnProcessor,
|
||||
FusedAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unets.unet_2d_blocks import (
|
||||
from .controlnet import ControlNetConditioningEmbedding
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
Downsample2D,
|
||||
@@ -42,8 +43,7 @@ from ..unets.unet_2d_blocks import (
|
||||
UNetMidBlock2DCrossAttn,
|
||||
Upsample2D,
|
||||
)
|
||||
from ..unets.unet_2d_condition import UNet2DConditionModel
|
||||
from .controlnet import ControlNetConditioningEmbedding
|
||||
from .unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -1062,8 +1062,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain
|
||||
tuple.
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
apply_control (`bool`, defaults to `True`):
|
||||
If `False`, the input is run only through the base model.
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
from ...utils import is_flax_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .controlnet import ControlNetModel, ControlNetOutput
|
||||
from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
|
||||
from .controlnet_hunyuan import (
|
||||
HunyuanControlNetOutput,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
)
|
||||
from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
|
||||
from .controlnet_sparsectrl import (
|
||||
SparseControlNetConditioningEmbedding,
|
||||
SparseControlNetModel,
|
||||
SparseControlNetOutput,
|
||||
)
|
||||
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
|
||||
if is_flax_available():
|
||||
from .controlnet_flax import FlaxControlNetModel
|
||||
@@ -1,872 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
get_down_block,
|
||||
)
|
||||
from ..unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlNetOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`ControlNetModel`].
|
||||
|
||||
Args:
|
||||
down_block_res_samples (`tuple[torch.Tensor]`):
|
||||
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
||||
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||
used to condition the original UNet's downsampling activations.
|
||||
mid_down_block_re_sample (`torch.Tensor`):
|
||||
The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
|
||||
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
class ControlNetConditioningEmbedding(nn.Module):
|
||||
"""
|
||||
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
||||
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
||||
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
||||
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
||||
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
||||
model) to encode image-space conditions ... into feature maps ..."
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
conditioning_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = nn.ModuleList([])
|
||||
|
||||
for i in range(len(block_out_channels) - 1):
|
||||
channel_in = block_out_channels[i]
|
||||
channel_out = block_out_channels[i + 1]
|
||||
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
||||
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
||||
|
||||
self.conv_out = zero_module(
|
||||
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
|
||||
def forward(self, conditioning):
|
||||
embedding = self.conv_in(conditioning)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
for block in self.blocks:
|
||||
embedding = block(embedding)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
embedding = self.conv_out(embedding)
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter.
|
||||
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||
)
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||
encoder_hid_dim_type = "text_proj"
|
||||
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||
|
||||
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type == "text_proj":
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||
elif encoder_hid_dim_type == "text_image_proj":
|
||||
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
|
||||
self.encoder_hid_proj = TextImageProjection(
|
||||
text_embed_dim=encoder_hid_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
elif encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||
)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
if projection_class_embeddings_input_dim is None:
|
||||
raise ValueError(
|
||||
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||
)
|
||||
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||
# 2. it projects from an arbitrary input dimension.
|
||||
#
|
||||
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == "text":
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
)
|
||||
elif addition_embed_type == "text_image":
|
||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
|
||||
# control net conditioning embedding
|
||||
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
block_out_channels=conditioning_embedding_out_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.controlnet_down_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
# mid
|
||||
mid_block_channel = block_out_channels[-1]
|
||||
|
||||
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_mid_block = controlnet_block
|
||||
|
||||
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
in_channels=mid_block_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
elif mid_block_type == "UNetMidBlock2D":
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
num_layers=0,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
add_attention=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
||||
|
||||
@classmethod
|
||||
def from_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 3,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
||||
|
||||
Parameters:
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
||||
where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||
)
|
||||
|
||||
controlnet = cls(
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
in_channels=unet.config.in_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
down_block_types=unet.config.down_block_types,
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
downsample_padding=unet.config.downsample_padding,
|
||||
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||
act_fn=unet.config.act_fn,
|
||||
norm_num_groups=unet.config.norm_num_groups,
|
||||
norm_eps=unet.config.norm_eps,
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
attention_head_dim=unet.config.attention_head_dim,
|
||||
num_attention_heads=unet.config.num_attention_heads,
|
||||
use_linear_projection=unet.config.use_linear_projection,
|
||||
class_embed_type=unet.config.class_embed_type,
|
||||
num_class_embeds=unet.config.num_class_embeds,
|
||||
upcast_attention=unet.config.upcast_attention,
|
||||
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||
mid_block_type=unet.config.mid_block_type,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
if load_weights_from_unet:
|
||||
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
||||
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||
|
||||
if controlnet.class_embedding:
|
||||
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||
|
||||
if hasattr(controlnet, "add_embedding"):
|
||||
controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
|
||||
|
||||
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
||||
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
||||
|
||||
return controlnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_sliceable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_sliceable_dims(module)
|
||||
|
||||
num_sliceable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
"""
|
||||
The [`ControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned,
|
||||
otherwise a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
elif channel_order == "bgr":
|
||||
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. Control net blocks
|
||||
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return ControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
@@ -1,536 +0,0 @@
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...models.attention_processor import AttentionProcessor
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_single_block_samples: Tuple[torch.Tensor]
|
||||
|
||||
|
||||
class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
||||
text_time_guidance_cls = (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
||||
)
|
||||
self.time_text_embed = text_time_guidance_cls(
|
||||
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||
)
|
||||
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.single_transformer_blocks)):
|
||||
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.union = num_mode is not None
|
||||
if self.union:
|
||||
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
|
||||
|
||||
if conditioning_embedding_channels is not None:
|
||||
self.input_hint_block = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
|
||||
)
|
||||
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
||||
else:
|
||||
self.input_hint_block = None
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self):
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls,
|
||||
transformer,
|
||||
num_layers: int = 4,
|
||||
num_single_layers: int = 10,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
load_weights_from_transformer=True,
|
||||
):
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers
|
||||
config["num_single_layers"] = num_single_layers
|
||||
config["attention_head_dim"] = attention_head_dim
|
||||
config["num_attention_heads"] = num_attention_heads
|
||||
|
||||
controlnet = cls(**config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
|
||||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
|
||||
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
|
||||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
|
||||
controlnet.single_transformer_blocks.load_state_dict(
|
||||
transformer.single_transformer_blocks.state_dict(), strict=False
|
||||
)
|
||||
|
||||
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
|
||||
|
||||
return controlnet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
controlnet_mode: torch.Tensor = None,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
controlnet_mode (`torch.Tensor`):
|
||||
The mode tensor of shape `(batch_size, 1)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
if self.input_hint_block is not None:
|
||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
|
||||
height = height_pw // self.config.patch_size
|
||||
width = width_pw // self.config.patch_size
|
||||
controlnet_cond = controlnet_cond.reshape(
|
||||
batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
|
||||
)
|
||||
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
|
||||
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
|
||||
# add
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
else:
|
||||
guidance = None
|
||||
temb = (
|
||||
self.time_text_embed(timestep, pooled_projections)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if self.union:
|
||||
# union mode
|
||||
if controlnet_mode is None:
|
||||
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
||||
# union mode emb
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
block_samples = ()
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
single_block_samples = ()
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
|
||||
|
||||
# controlnet block
|
||||
controlnet_block_samples = ()
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
|
||||
block_sample = controlnet_block(block_sample)
|
||||
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
||||
|
||||
controlnet_single_block_samples = ()
|
||||
for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
|
||||
single_block_sample = controlnet_block(single_block_sample)
|
||||
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
|
||||
|
||||
# scaling
|
||||
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
|
||||
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
|
||||
|
||||
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
|
||||
controlnet_single_block_samples = (
|
||||
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_samples, controlnet_single_block_samples)
|
||||
|
||||
return FluxControlNetOutput(
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
)
|
||||
|
||||
|
||||
class FluxMultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
`FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
|
||||
|
||||
This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
|
||||
compatible with `FluxControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[FluxControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`FluxControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
controlnet_mode: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FluxControlNetOutput, Tuple]:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1 and self.nets[0].union:
|
||||
controlnet = self.nets[0]
|
||||
|
||||
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
|
||||
# Regular Multi-ControlNets
|
||||
# load all ControlNets into memories
|
||||
else:
|
||||
for i, (image, mode, scale, controlnet) in enumerate(
|
||||
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
|
||||
):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
if block_samples is not None and control_block_samples is not None:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
if single_block_samples is not None and control_single_block_samples is not None:
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
|
||||
return control_block_samples, control_single_block_samples
|
||||
@@ -1,422 +0,0 @@
|
||||
# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import JointTransformerBlock
|
||||
from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .controlnet import BaseOutput, zero_module
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class SD3ControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
|
||||
|
||||
class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 128,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
num_layers: int = 18,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 18,
|
||||
joint_attention_dim: int = 4096,
|
||||
caption_projection_dim: int = 1152,
|
||||
pooled_projection_dim: int = 2048,
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
default_out_channels = in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
)
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||
)
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
|
||||
|
||||
# `attention_head_dim` is doubled to account for the mixing.
|
||||
# It needs to crafted when we get the actual checkpoints.
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
context_pre_only=False,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
pos_embed_input = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels + extra_conditioning_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_type=None,
|
||||
)
|
||||
self.pos_embed_input = zero_module(pos_embed_input)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
||||
|
||||
Parameters:
|
||||
chunk_size (`int`, *optional*):
|
||||
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
||||
over each tensor of dim=`dim`.
|
||||
dim (`int`, *optional*, defaults to `0`):
|
||||
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
||||
or dim=1 (sequence length).
|
||||
"""
|
||||
if dim not in [0, 1]:
|
||||
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
||||
|
||||
# By default chunk size is 1
|
||||
chunk_size = chunk_size or 1
|
||||
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedJointAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
|
||||
):
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers or config.num_layers
|
||||
config["extra_conditioning_channels"] = num_extra_conditioning_channels
|
||||
controlnet = cls(**config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
|
||||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
|
||||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
|
||||
|
||||
controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
|
||||
|
||||
return controlnet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
pooled_projections: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`SD3Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
# add
|
||||
hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
|
||||
|
||||
block_res_samples = ()
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
)
|
||||
|
||||
block_res_samples = block_res_samples + (hidden_states,)
|
||||
|
||||
controlnet_block_res_samples = ()
|
||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
||||
block_res_sample = controlnet_block(block_res_sample)
|
||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
||||
|
||||
# 6. scaling
|
||||
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_res_samples,)
|
||||
|
||||
return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
|
||||
|
||||
|
||||
class SD3MultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
`SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
|
||||
|
||||
This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
|
||||
compatible with `SD3ControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[SD3ControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`SD3ControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
pooled_projections: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SD3ControlNetOutput, Tuple]:
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
pooled_projections=pooled_projections,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=scale,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
|
||||
]
|
||||
control_block_samples = (tuple(control_block_samples),)
|
||||
|
||||
return control_block_samples
|
||||
@@ -1,788 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
|
||||
from ..unets.unet_2d_condition import UNet2DConditionModel
|
||||
from ..unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseControlNetOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`SparseControlNetModel`].
|
||||
|
||||
Args:
|
||||
down_block_res_samples (`tuple[torch.Tensor]`):
|
||||
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
||||
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||
used to condition the original UNet's downsampling activations.
|
||||
mid_down_block_re_sample (`torch.Tensor`):
|
||||
The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
|
||||
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
class SparseControlNetConditioningEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
conditioning_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
self.blocks = nn.ModuleList([])
|
||||
|
||||
for i in range(len(block_out_channels) - 1):
|
||||
channel_in = block_out_channels[i]
|
||||
channel_out = block_out_channels[i + 1]
|
||||
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
||||
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
||||
|
||||
self.conv_out = zero_module(
|
||||
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
|
||||
def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
|
||||
embedding = self.conv_in(conditioning)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
for block in self.blocks:
|
||||
embedding = block(embedding)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
embedding = self.conv_out(embedding)
|
||||
return embedding
|
||||
|
||||
|
||||
class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
|
||||
Models](https://arxiv.org/abs/2311.16933).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
conditioning_channels (`int`, defaults to 4):
|
||||
The number of input channels in the controlnet conditional embedding module. If
|
||||
`concat_condition_embedding` is True, the value provided here is incremented by 1.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer layers to use in each layer in the middle block.
|
||||
attention_head_dim (`int` or `Tuple[int]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
num_attention_heads (`int` or `Tuple[int]`, *optional*):
|
||||
The number of heads to use for multi-head attention.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `rgb`):
|
||||
motion_max_seq_length (`int`, defaults to `32`):
|
||||
The maximum sequence length to use in the motion module.
|
||||
motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`):
|
||||
The number of heads to use in each attention layer of the motion module.
|
||||
concat_conditioning_mask (`bool`, defaults to `True`):
|
||||
use_simplified_condition_embedding (`bool`, defaults to `True`):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
motion_max_seq_length: int = 32,
|
||||
motion_num_attention_heads: int = 8,
|
||||
concat_conditioning_mask: bool = True,
|
||||
use_simplified_condition_embedding: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_simplified_condition_embedding = use_simplified_condition_embedding
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
if isinstance(temporal_transformer_layers_per_block, int):
|
||||
temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||
)
|
||||
|
||||
if concat_conditioning_mask:
|
||||
conditioning_channels = conditioning_channels + 1
|
||||
|
||||
self.concat_conditioning_mask = concat_conditioning_mask
|
||||
|
||||
# control net conditioning embedding
|
||||
if use_simplified_condition_embedding:
|
||||
self.controlnet_cond_embedding = zero_module(
|
||||
nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
)
|
||||
else:
|
||||
self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
block_out_channels=conditioning_embedding_out_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.controlnet_down_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(cross_attention_dim, int):
|
||||
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
if isinstance(motion_num_attention_heads, int):
|
||||
motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
if down_block_type == "CrossAttnDownBlockMotion":
|
||||
down_block = CrossAttnDownBlockMotion(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=0,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_pre_norm=True,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
cross_attention_dim=cross_attention_dim[i],
|
||||
add_downsample=not is_final_block,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
temporal_num_attention_heads=motion_num_attention_heads[i],
|
||||
temporal_max_seq_length=motion_max_seq_length,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
|
||||
temporal_double_self_attention=False,
|
||||
)
|
||||
elif down_block_type == "DownBlockMotion":
|
||||
down_block = DownBlockMotion(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=0,
|
||||
num_layers=layers_per_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_pre_norm=True,
|
||||
add_downsample=not is_final_block,
|
||||
temporal_num_attention_heads=motion_num_attention_heads[i],
|
||||
temporal_max_seq_length=motion_max_seq_length,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
|
||||
temporal_double_self_attention=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid `block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`"
|
||||
)
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
# mid
|
||||
mid_block_channels = block_out_channels[-1]
|
||||
|
||||
controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_mid_block = controlnet_block
|
||||
|
||||
if transformer_layers_per_mid_block is None:
|
||||
transformer_layers_per_mid_block = (
|
||||
transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
|
||||
)
|
||||
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=mid_block_channels,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=0,
|
||||
num_layers=1,
|
||||
transformer_layers_per_block=transformer_layers_per_mid_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_pre_norm=True,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim[-1],
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
attention_type="default",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 3,
|
||||
) -> "SparseControlNetModel":
|
||||
r"""
|
||||
Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`].
|
||||
|
||||
Parameters:
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also
|
||||
copied where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
)
|
||||
down_block_types = unet.config.down_block_types
|
||||
|
||||
for i in range(len(down_block_types)):
|
||||
if "CrossAttn" in down_block_types[i]:
|
||||
down_block_types[i] = "CrossAttnDownBlockMotion"
|
||||
elif "Down" in down_block_types[i]:
|
||||
down_block_types[i] = "DownBlockMotion"
|
||||
else:
|
||||
raise ValueError("Invalid `block_type` encountered. Must be a cross-attention or down block")
|
||||
|
||||
controlnet = cls(
|
||||
in_channels=unet.config.in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
down_block_types=unet.config.down_block_types,
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
downsample_padding=unet.config.downsample_padding,
|
||||
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||
act_fn=unet.config.act_fn,
|
||||
norm_num_groups=unet.config.norm_num_groups,
|
||||
norm_eps=unet.config.norm_eps,
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
attention_head_dim=unet.config.attention_head_dim,
|
||||
num_attention_heads=unet.config.num_attention_heads,
|
||||
use_linear_projection=unet.config.use_linear_projection,
|
||||
upcast_attention=unet.config.upcast_attention,
|
||||
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
)
|
||||
|
||||
if load_weights_from_unet:
|
||||
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False)
|
||||
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False)
|
||||
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False)
|
||||
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
||||
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
||||
|
||||
return controlnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_sliceable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_sliceable_dims(module)
|
||||
|
||||
num_sliceable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
conditioning_mask: Optional[torch.Tensor] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
"""
|
||||
The [`SparseControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
Returns:
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape
|
||||
sample = torch.zeros_like(sample)
|
||||
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
elif channel_order == "bgr":
|
||||
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
emb = emb.repeat_interleave(sample_num_frames, dim=0)
|
||||
|
||||
# 2. pre-process
|
||||
batch_size, channels, num_frames, height, width = sample.shape
|
||||
|
||||
sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
batch_frames, channels, height, width = sample.shape
|
||||
sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width)
|
||||
|
||||
if self.concat_conditioning_mask:
|
||||
controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
|
||||
|
||||
batch_size, channels, num_frames, height, width = controlnet_cond.shape
|
||||
controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape(
|
||||
batch_size * num_frames, channels, height, width
|
||||
)
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
batch_frames, channels, height, width = controlnet_cond.shape
|
||||
controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width)
|
||||
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
batch_size, num_frames, channels, height, width = sample.shape
|
||||
sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
num_frames=num_frames,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. Control net blocks
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return SparseControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
||||
|
||||
|
||||
# Copied from diffusers.models.controlnets.controlnet.zero_module
|
||||
def zero_module(module: nn.Module) -> nn.Module:
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
@@ -1,183 +0,0 @@
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
Multiple `ControlNetModel` wrapper class for Multi-ControlNet
|
||||
|
||||
This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
|
||||
compatible with `ControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[ControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`ControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
down_samples, mid_sample = controlnet(
|
||||
sample=sample,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=scale,
|
||||
class_labels=class_labels,
|
||||
timestep_cond=timestep_cond,
|
||||
attention_mask=attention_mask,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guess_mode=guess_mode,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||
else:
|
||||
down_block_res_samples = [
|
||||
samples_prev + samples_curr
|
||||
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
||||
]
|
||||
mid_block_res_sample += mid_sample
|
||||
|
||||
return down_block_res_samples, mid_block_res_sample
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
`[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful when in distributed training like
|
||||
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
||||
the main process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
||||
need to replace `torch.save` by another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||
"""
|
||||
for idx, controlnet in enumerate(self.nets):
|
||||
suffix = "" if idx == 0 else f"_{idx}"
|
||||
controlnet.save_pretrained(
|
||||
save_directory + suffix,
|
||||
is_main_process=is_main_process,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
||||
the model, you should first set it back in training mode with `model.train()`.
|
||||
|
||||
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
||||
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
||||
task.
|
||||
|
||||
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
||||
weights are discarded.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_path (`os.PathLike`):
|
||||
A path to a *directory* containing model weights saved using
|
||||
[`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
|
||||
`./my_model_directory/controlnet`.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
|
||||
GPU and the available CPU RAM if unset.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
variant (`str`, *optional*):
|
||||
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
||||
ignored when using `from_flax`.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
|
||||
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
|
||||
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
|
||||
"""
|
||||
idx = 0
|
||||
controlnets = []
|
||||
|
||||
# load controlnet and append to list until no controlnet directory exists anymore
|
||||
# first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
|
||||
# second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
|
||||
model_path_to_load = pretrained_model_path
|
||||
while os.path.isdir(model_path_to_load):
|
||||
controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
|
||||
controlnets.append(controlnet)
|
||||
|
||||
idx += 1
|
||||
model_path_to_load = pretrained_model_path + f"_{idx}"
|
||||
|
||||
logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
|
||||
|
||||
if len(controlnets) == 0:
|
||||
raise ValueError(
|
||||
f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
|
||||
)
|
||||
|
||||
return cls(controlnets)
|
||||
@@ -564,42 +564,6 @@ def get_3d_rotary_pos_embed(
|
||||
return cos, sin
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed_allegro(
|
||||
embed_dim,
|
||||
crops_coords,
|
||||
grid_size,
|
||||
temporal_size,
|
||||
interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
|
||||
theta: int = 10000,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
# TODO(aryan): docs
|
||||
start, stop = crops_coords
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 3
|
||||
dim_h = embed_dim // 3
|
||||
dim_w = embed_dim // 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = get_1d_rotary_pos_embed(
|
||||
dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False
|
||||
)
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = get_1d_rotary_pos_embed(
|
||||
dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False
|
||||
)
|
||||
freqs_w = get_1d_rotary_pos_embed(
|
||||
dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False
|
||||
)
|
||||
|
||||
return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
|
||||
"""
|
||||
RoPE for image tokens with 2d structure.
|
||||
@@ -720,7 +684,7 @@ def get_1d_rotary_pos_embed(
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
# stable audio, allegro
|
||||
# stable audio
|
||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
@@ -779,24 +743,6 @@ def apply_rotary_emb(
|
||||
return x_out.type_as(x)
|
||||
|
||||
|
||||
def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
|
||||
# TODO(aryan): rewrite
|
||||
def apply_1d_rope(tokens, pos, cos, sin):
|
||||
cos = F.embedding(pos, cos)[:, None, :, :]
|
||||
sin = F.embedding(pos, sin)[:, None, :, :]
|
||||
x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :]
|
||||
tokens_rotated = torch.cat((-x2, x1), dim=-1)
|
||||
return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype)
|
||||
|
||||
(t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis
|
||||
t, h, w = x.chunk(3, dim=-1)
|
||||
t = apply_1d_rope(t, positions[0], t_cos, t_sin)
|
||||
h = apply_1d_rope(h, positions[1], h_cos, h_sin)
|
||||
w = apply_1d_rope(w, positions[2], w_cos, w_sin)
|
||||
x = torch.cat([t, h, w], dim=-1)
|
||||
return x
|
||||
|
||||
|
||||
class FluxPosEmbed(nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
|
||||
@@ -22,7 +22,10 @@ import torch.nn.functional as F
|
||||
|
||||
from ..utils import is_torch_version
|
||||
from .activations import get_activation
|
||||
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
from .embeddings import (
|
||||
CombinedTimestepLabelEmbeddings,
|
||||
PixArtAlphaCombinedTimestepSizeEmbeddings,
|
||||
)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
@@ -290,7 +293,6 @@ class AdaLayerNormSingle(nn.Module):
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# No modulation happening here.
|
||||
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
|
||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||
|
||||
@@ -390,7 +392,7 @@ class LuminaLayerNormContinuous(nn.Module):
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
elif norm_type == "rms_norm":
|
||||
if norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
@@ -554,15 +556,3 @@ class GlobalResponseNorm(nn.Module):
|
||||
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * nx) + self.beta + x
|
||||
|
||||
|
||||
class LpNorm(nn.Module):
|
||||
def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12):
|
||||
super().__init__()
|
||||
|
||||
self.p = p
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
|
||||
|
||||
@@ -14,7 +14,6 @@ if is_torch_available():
|
||||
from .stable_audio_transformer import StableAudioDiTModel
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_allegro import AllegroTransformer3DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
|
||||
@@ -1,422 +0,0 @@
|
||||
# Copyright 2024 The RhymesAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import AllegroAttnProcessor2_0, Attention
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class AllegroTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model.
|
||||
|
||||
Args:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
cross_attention_dim (`int`, defaults to `2304`):
|
||||
The dimension of the cross attention features.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to be used in feed-forward.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether or not to use bias in attention projection layers.
|
||||
only_cross_attention (`bool`, defaults to `False`):
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
Epsilon value for normalization layers.
|
||||
final_dropout (`bool` defaults to `False`):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
attention_bias: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Self Attention
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=None,
|
||||
processor=AllegroAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 2. Cross Attention
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
processor=AllegroAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 3. Feed Forward
|
||||
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
)
|
||||
|
||||
# 4. Scale-shift
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb=None,
|
||||
) -> torch.Tensor:
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
attn_output = gate_msa * attn_output
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 1. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = hidden_states
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
image_rotary_emb=None,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
# TODO(aryan): maybe following line is not required
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
"""
|
||||
A 3D Transformer model for video-like data.
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of spatial patches to use in the patch embedding layer.
|
||||
patch_size_t (`int`, defaults to `1`):
|
||||
The size of temporal patches to use in the patch embedding layer.
|
||||
num_attention_heads (`int`, defaults to `24`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `96`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `4`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `4`):
|
||||
The number of channels in the output.
|
||||
num_layers (`int`, defaults to `32`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
cross_attention_dim (`int`, defaults to `2304`):
|
||||
The dimension of the cross attention features.
|
||||
attention_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in the attention projection layers.
|
||||
sample_height (`int`, defaults to `90`):
|
||||
The height of the input latents.
|
||||
sample_width (`int`, defaults to `160`):
|
||||
The width of the input latents.
|
||||
sample_frames (`int`, defaults to `22`):
|
||||
The number of frames in the input latents.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
norm_elementwise_affine (`bool`, defaults to `False`):
|
||||
Whether or not to use elementwise affine in normalization layers.
|
||||
norm_eps (`float`, defaults to `1e-6`):
|
||||
The epsilon value to use in normalization layers.
|
||||
caption_channels (`int`, defaults to `4096`):
|
||||
Number of channels to use for projecting the caption embeddings.
|
||||
interpolation_scale_h (`float`, defaults to `2.0`):
|
||||
Scaling factor to apply in 3D positional embeddings across height dimension.
|
||||
interpolation_scale_w (`float`, defaults to `2.0`):
|
||||
Scaling factor to apply in 3D positional embeddings across width dimension.
|
||||
interpolation_scale_t (`float`, defaults to `2.2`):
|
||||
Scaling factor to apply in 3D positional embeddings across time dimension.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
patch_size_t: int = 1,
|
||||
num_attention_heads: int = 24,
|
||||
attention_head_dim: int = 96,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 4,
|
||||
num_layers: int = 32,
|
||||
dropout: float = 0.0,
|
||||
cross_attention_dim: int = 2304,
|
||||
attention_bias: bool = True,
|
||||
sample_height: int = 90,
|
||||
sample_width: int = 160,
|
||||
sample_frames: int = 22,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
caption_channels: int = 4096,
|
||||
interpolation_scale_h: float = 2.0,
|
||||
interpolation_scale_w: float = 2.0,
|
||||
interpolation_scale_t: float = 2.2,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
interpolation_scale_t = (
|
||||
interpolation_scale_t
|
||||
if interpolation_scale_t is not None
|
||||
else ((sample_frames - 1) // 16 + 1)
|
||||
if sample_frames % 2 == 1
|
||||
else sample_frames // 16
|
||||
)
|
||||
interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30
|
||||
interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40
|
||||
|
||||
# 1. Patch embedding
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_height,
|
||||
width=sample_width,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_type=None,
|
||||
)
|
||||
|
||||
# 2. Transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
AllegroTransformerBlock(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Output projection & norm
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
# 4. Timestep embeddings
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)
|
||||
|
||||
# 5. Caption projection
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t = self.config.patch_size_t
|
||||
p = self.config.patch_size
|
||||
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) attention_mask_vid, attention_mask_img = None, None
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
# b, frame+use_image_num, h, w -> a video with images
|
||||
# b, 1, h, w -> only images
|
||||
attention_mask = attention_mask.to(hidden_states.dtype)
|
||||
attention_mask = attention_mask[:, :num_frames] # [batch_size, num_frames, height, width]
|
||||
|
||||
if attention_mask.numel() > 0:
|
||||
attention_mask = attention_mask.unsqueeze(1) # [batch_size, 1, num_frames, height, width]
|
||||
attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
|
||||
attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1)
|
||||
|
||||
attention_mask = (
|
||||
(1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None
|
||||
)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Timestep embeddings
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 2. Patch embeddings
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
|
||||
|
||||
# 3. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
# TODO(aryan): Implement gradient checkpointing
|
||||
if self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep,
|
||||
attention_mask,
|
||||
encoder_attention_mask,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=timestep,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# 4. Output normalization & projection
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 5. Unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -27,13 +27,11 @@ from ...models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0_NPU,
|
||||
FusedFluxAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -66,10 +64,7 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
if is_torch_npu_available():
|
||||
processor = FluxAttnProcessor2_0_NPU()
|
||||
else:
|
||||
processor = FluxAttnProcessor2_0()
|
||||
processor = FluxAttnProcessor2_0()
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
|
||||
@@ -109,7 +109,9 @@ class MochiTransformerBlock(nn.Module):
|
||||
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
|
||||
self.ff = FeedForward(
|
||||
dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True
|
||||
)
|
||||
self.ff_context = None
|
||||
if not context_pre_only:
|
||||
self.ff_context = FeedForward(
|
||||
@@ -117,6 +119,7 @@ class MochiTransformerBlock(nn.Module):
|
||||
inner_dim=self.ff_context_inner_dim,
|
||||
activation_fn=activation_fn,
|
||||
bias=False,
|
||||
flip_gate=True,
|
||||
)
|
||||
|
||||
self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
@@ -289,7 +292,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
num_attention_heads=8,
|
||||
)
|
||||
|
||||
self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0))
|
||||
self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2))
|
||||
self.rope = MochiRoPE()
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
@@ -350,7 +353,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -116,7 +116,6 @@ else:
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["allegro"] = ["AllegroPipeline"]
|
||||
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
|
||||
_import_structure["animatediff"] = [
|
||||
"AnimateDiffPipeline",
|
||||
@@ -456,7 +455,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .allegro import AllegroPipeline
|
||||
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
|
||||
from .animatediff import (
|
||||
AnimateDiffControlNetPipeline,
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_allegro"] = ["AllegroPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_allegro import AllegroPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,918 +0,0 @@
|
||||
# Copyright 2024 The RhymesAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import math
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro
|
||||
from ...models.embeddings import get_3d_rotary_pos_embed_allegro
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
deprecate,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from .pipeline_output import AllegroPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_bs4_available():
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import AutoencoderKLAllegro, AllegroPipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda")
|
||||
|
||||
>>> prompt = (
|
||||
... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
|
||||
... "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this "
|
||||
... "location might be a popular spot for docking fishing boats."
|
||||
... )
|
||||
>>> video = pipe(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=15)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class AllegroPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using Allegro.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AllegroAutoEncoderKL3D`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. PixArt-Alpha uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`AllegroTransformer3DModel`]):
|
||||
A text conditioned `AllegroTransformer3DModel` to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
||||
"""
|
||||
|
||||
bad_punct_regex = re.compile(
|
||||
r"["
|
||||
+ "#®•©™&@·º½¾¿¡§~"
|
||||
+ r"\)"
|
||||
+ r"\("
|
||||
+ r"\]"
|
||||
+ r"\["
|
||||
+ r"\}"
|
||||
+ r"\{"
|
||||
+ r"\|"
|
||||
+ "\\"
|
||||
+ r"\/"
|
||||
+ r"\*"
|
||||
+ r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKLAllegro,
|
||||
transformer: AllegroTransformer3DModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
self.vae_scale_factor_spatial = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->512, num_images_per_prompt->num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str = "",
|
||||
num_videos_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 512,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
||||
PixArt-Alpha, this should be "".
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
|
||||
string.
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 512): Maximum sequence length to use for the prompt.
|
||||
"""
|
||||
|
||||
if "mask_feature" in kwargs:
|
||||
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
|
||||
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# See Section 3.1. of the paper.
|
||||
max_length = max_sequence_length
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because T5 can only handle sequences up to"
|
||||
f" {max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
elif self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
|
||||
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
negative_prompt_attention_mask = uncond_input.attention_mask
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
else:
|
||||
negative_prompt_embeds = None
|
||||
negative_prompt_attention_mask = None
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if num_frames <= 0:
|
||||
raise ValueError(f"`num_frames` have to be positive but is {num_frames}.")
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
text = [text]
|
||||
|
||||
def process(text: str):
|
||||
if clean_caption:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
else:
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
return [process(t) for t in text]
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
||||
def _clean_caption(self, caption):
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = ftfy.fix_text(caption)
|
||||
caption = html.unescape(html.unescape(caption))
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if num_frames % 2 == 0:
|
||||
num_frames = math.ceil(num_frames / self.vae_scale_factor_temporal)
|
||||
else:
|
||||
num_frames = math.ceil((num_frames - 1) / self.vae_scale_factor_temporal) + 1
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
frames = self.vae.decode(latents).sample
|
||||
frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, channels, num_frames, height, width]
|
||||
return frames
|
||||
|
||||
def _prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
batch_size: int,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
):
|
||||
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
|
||||
start, stop = (0, 0), (grid_height, grid_width)
|
||||
freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=(start, stop),
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
interpolation_scale=(
|
||||
self.transformer.config.interpolation_scale_t,
|
||||
self.transformer.config.interpolation_scale_h,
|
||||
self.transformer.config.interpolation_scale_w,
|
||||
),
|
||||
)
|
||||
|
||||
grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long)
|
||||
grid_h = torch.from_numpy(grid_h).to(device=device, dtype=torch.long)
|
||||
grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long)
|
||||
|
||||
pos = torch.cartesian_prod(grid_t, grid_h, grid_w)
|
||||
pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous()
|
||||
grid_t, grid_h, grid_w = pos
|
||||
|
||||
freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device))
|
||||
freqs_h = (freqs_h[0].to(device=device), freqs_h[1].to(device=device))
|
||||
freqs_w = (freqs_w[0].to(device=device), freqs_w[1].to(device=device))
|
||||
|
||||
return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: str = "",
|
||||
num_inference_steps: int = 100,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
num_frames: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
clean_caption: bool = True,
|
||||
max_sequence_length: int = 512,
|
||||
) -> Union[AllegroPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the video generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower video quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
num_frames: (`int`, *optional*, defaults to 88):
|
||||
The number controls the generated video frames.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The height in pixels of the generated video.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The width in pixels of the generated video.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate video. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
max_sequence_length (`int` defaults to `512`):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated videos.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
num_videos_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
num_frames = num_frames or self.transformer.config.sample_frames * self.vae_scale_factor_temporal
|
||||
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
||||
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
||||
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
if prompt_embeds.ndim == 3:
|
||||
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
latent_channels,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare rotary embeddings
|
||||
image_rotary_emb = self._prepare_rotary_positional_embeddings(
|
||||
batch_size, height, width, latents.size(2), device
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.decode_latents(latents)
|
||||
video = video[:, :, :num_frames, :height, :width]
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return AllegroPipelineOutput(frames=video)
|
||||
@@ -1,23 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllegroPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Allegro pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
||||
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
|
||||
from ...models.controlnet_sparsectrl import SparseControlNetModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
|
||||
@@ -1,12 +1,183 @@
|
||||
from ...models.controlnets.multicontrolnet import MultiControlNetModel
|
||||
from ...utils import deprecate, logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MultiControlNetModel(MultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead."
|
||||
deprecate("MultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
class MultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
Multiple `ControlNetModel` wrapper class for Multi-ControlNet
|
||||
|
||||
This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
|
||||
compatible with `ControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[ControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`ControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
down_samples, mid_sample = controlnet(
|
||||
sample=sample,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=scale,
|
||||
class_labels=class_labels,
|
||||
timestep_cond=timestep_cond,
|
||||
attention_mask=attention_mask,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guess_mode=guess_mode,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||
else:
|
||||
down_block_res_samples = [
|
||||
samples_prev + samples_curr
|
||||
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
||||
]
|
||||
mid_block_res_sample += mid_sample
|
||||
|
||||
return down_block_res_samples, mid_block_res_sample
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
`[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful when in distributed training like
|
||||
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
||||
the main process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
||||
need to replace `torch.save` by another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||
"""
|
||||
for idx, controlnet in enumerate(self.nets):
|
||||
suffix = "" if idx == 0 else f"_{idx}"
|
||||
controlnet.save_pretrained(
|
||||
save_directory + suffix,
|
||||
is_main_process=is_main_process,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
||||
the model, you should first set it back in training mode with `model.train()`.
|
||||
|
||||
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
||||
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
||||
task.
|
||||
|
||||
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
||||
weights are discarded.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_path (`os.PathLike`):
|
||||
A path to a *directory* containing model weights saved using
|
||||
[`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
|
||||
`./my_model_directory/controlnet`.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
|
||||
GPU and the available CPU RAM if unset.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
variant (`str`, *optional*):
|
||||
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
||||
ignored when using `from_flax`.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
|
||||
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
|
||||
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
|
||||
"""
|
||||
idx = 0
|
||||
controlnets = []
|
||||
|
||||
# load controlnet and append to list until no controlnet directory exists anymore
|
||||
# first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
|
||||
# second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
|
||||
model_path_to_load = pretrained_model_path
|
||||
while os.path.isdir(model_path_to_load):
|
||||
controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
|
||||
controlnets.append(controlnet)
|
||||
|
||||
idx += 1
|
||||
model_path_to_load = pretrained_model_path + f"_{idx}"
|
||||
|
||||
logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
|
||||
|
||||
if len(controlnets) == 0:
|
||||
raise ValueError(
|
||||
f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
|
||||
)
|
||||
|
||||
return cls(controlnets)
|
||||
|
||||
@@ -26,7 +26,7 @@ from transformers import (
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from ...models.transformers import SD3Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
|
||||
+1
-1
@@ -26,7 +26,7 @@ from transformers import (
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from ...models.transformers import SD3Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
|
||||
@@ -371,7 +371,7 @@ class FluxPipeline(
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
@@ -427,7 +427,7 @@ class FluxPipeline(
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
@@ -437,7 +437,7 @@ class FluxPipeline(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
|
||||
@@ -27,7 +27,7 @@ from transformers import (
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
@@ -452,7 +452,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
@@ -462,7 +462,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
|
||||
@@ -13,7 +13,7 @@ from transformers import (
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
@@ -407,7 +407,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
@@ -495,7 +495,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
@@ -505,7 +505,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
|
||||
@@ -14,7 +14,7 @@ from transformers import (
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
@@ -417,7 +417,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
@@ -522,7 +522,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
@@ -532,7 +532,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
|
||||
@@ -391,7 +391,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
@@ -479,7 +479,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
@@ -489,7 +489,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
|
||||
@@ -395,7 +395,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
@@ -500,7 +500,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
@@ -510,7 +510,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
|
||||
@@ -52,8 +52,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
>>> pipe.enable_vae_tiling()
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
||||
>>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
|
||||
>>> export_to_video(frames, "mochi.mp4")
|
||||
@@ -87,7 +86,7 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
|
||||
quadratic_sigma_schedule = [
|
||||
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
|
||||
]
|
||||
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
|
||||
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
|
||||
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
||||
return sigma_schedule
|
||||
|
||||
@@ -347,6 +346,7 @@ class MochiPipeline(DiffusionPipeline):
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
@@ -355,6 +355,8 @@ class MochiPipeline(DiffusionPipeline):
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if (num_frames - 1) % self.vae_temporal_scale_factor != 0:
|
||||
raise ValueError(f"Expected `num_frames - 1` to be divisible by {self.vae_temporal_scale_factor=}")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
@@ -477,7 +479,7 @@ class MochiPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_frames: int = 19,
|
||||
num_frames: int = 16,
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.5,
|
||||
@@ -501,11 +503,11 @@ class MochiPipeline(DiffusionPipeline):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to `self.default_height`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, *optional*, defaults to `self.default_width`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `19`):
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_frames (`int`, defaults to 16):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
@@ -575,6 +577,7 @@ class MochiPipeline(DiffusionPipeline):
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
@@ -655,7 +658,7 @@ class MochiPipeline(DiffusionPipeline):
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
@@ -697,21 +700,6 @@ class MochiPipeline(DiffusionPipeline):
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
|
||||
@@ -198,31 +198,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
|
||||
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
||||
return variant_filename
|
||||
|
||||
def find_component(filename):
|
||||
if not len(filename.split("/")) == 2:
|
||||
return
|
||||
component = filename.split("/")[0]
|
||||
return component
|
||||
|
||||
def has_sharded_variant(component, variant, variant_filenames):
|
||||
# If component exists check for sharded variant index filename
|
||||
# If component doesn't exist check main dir for sharded variant index filename
|
||||
component = component + "/" if component else ""
|
||||
variant_index_re = re.compile(
|
||||
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
|
||||
|
||||
for filename in non_variant_filenames:
|
||||
if convert_to_variant(filename) in variant_filenames:
|
||||
continue
|
||||
|
||||
component = find_component(filename)
|
||||
# If a sharded variant exists skip adding to allowed patterns
|
||||
if has_sharded_variant(component, variant, variant_filenames):
|
||||
continue
|
||||
|
||||
usable_filenames.add(filename)
|
||||
for f in non_variant_filenames:
|
||||
variant_filename = convert_to_variant(f)
|
||||
if variant_filename not in usable_filenames:
|
||||
usable_filenames.add(f)
|
||||
|
||||
return usable_filenames, variant_filenames
|
||||
|
||||
|
||||
@@ -43,9 +43,6 @@ def set_seed(seed: int):
|
||||
|
||||
Args:
|
||||
seed (`int`): The seed to set.
|
||||
|
||||
Returns:
|
||||
`None`
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
@@ -61,17 +58,6 @@ def compute_snr(noise_scheduler, timesteps):
|
||||
"""
|
||||
Computes SNR as per
|
||||
https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
for the given timesteps using the provided noise scheduler.
|
||||
|
||||
Args:
|
||||
noise_scheduler (`NoiseScheduler`):
|
||||
An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
|
||||
the SNR values.
|
||||
timesteps (`torch.Tensor`):
|
||||
A tensor of timesteps for which the SNR is computed.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: A tensor containing the computed SNR values for each timestep.
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
@@ -393,7 +379,7 @@ class EMAModel:
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
|
||||
_, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
|
||||
_, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
|
||||
model = model_cls.from_pretrained(path)
|
||||
|
||||
ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
|
||||
|
||||
@@ -2,21 +2,6 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class AllegroTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AsymmetricAutoencoderKL(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -62,21 +47,6 @@ class AutoencoderKL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLAllegro(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoX(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2,21 +2,6 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class AllegroPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -325,7 +325,7 @@ def get_cached_module_file(
|
||||
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
|
||||
# that hash, to only copy when there is a modification but it seems overkill for now).
|
||||
# The only reason we do the copy is to avoid putting too many folders in sys.path.
|
||||
shutil.copyfile(resolved_module_file, submodule_path / module_file)
|
||||
shutil.copy(resolved_module_file, submodule_path / module_file)
|
||||
for module_needed in modules_needed:
|
||||
if len(module_needed.split(".")) == 2:
|
||||
module_needed = "/".join(module_needed.split("."))
|
||||
@@ -333,7 +333,7 @@ def get_cached_module_file(
|
||||
if not os.path.exists(submodule_path / module_folder):
|
||||
os.makedirs(submodule_path / module_folder)
|
||||
module_needed = f"{module_needed}.py"
|
||||
shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
|
||||
shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
|
||||
else:
|
||||
# Get the commit hash
|
||||
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
|
||||
@@ -350,7 +350,7 @@ def get_cached_module_file(
|
||||
module_folder = module_file.split("/")[0]
|
||||
if not os.path.exists(submodule_path / module_folder):
|
||||
os.makedirs(submodule_path / module_folder)
|
||||
shutil.copyfile(resolved_module_file, submodule_path / module_file)
|
||||
shutil.copy(resolved_module_file, submodule_path / module_file)
|
||||
|
||||
# Make sure we also have every file with relative
|
||||
for module_needed in modules_needed:
|
||||
|
||||
@@ -57,7 +57,6 @@ _required_transformers_version = is_transformers_available() and version.parse(
|
||||
) > version.parse("4.33")
|
||||
|
||||
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
|
||||
BIG_GPU_MEMORY = int(os.getenv("BIG_GPU_MEMORY", 40))
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@@ -311,26 +310,6 @@ def require_torch_accelerator_with_fp64(test_case):
|
||||
)
|
||||
|
||||
|
||||
def require_big_gpu_with_torch_cuda(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a bigger GPU (24GB) for execution. Some example pipelines: Flux, SD3, Cog,
|
||||
etc.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return unittest.skip("test requires PyTorch CUDA")(test_case)
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(0)
|
||||
total_memory = device_properties.total_memory / (1024**3)
|
||||
return unittest.skipUnless(
|
||||
total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_training(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for training."""
|
||||
return unittest.skipUnless(
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.loaders.lora_base import LoraBaseMixin
|
||||
|
||||
|
||||
class UtilityMethodDeprecationTests(unittest.TestCase):
|
||||
def test_fetch_state_dict_cls_method_raises_warning(self):
|
||||
state_dict = torch.nn.Linear(3, 3).state_dict()
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
_ = LoraBaseMixin._fetch_state_dict(
|
||||
state_dict,
|
||||
weight_name=None,
|
||||
use_safetensors=False,
|
||||
local_files_only=True,
|
||||
cache_dir=None,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
token=None,
|
||||
revision=None,
|
||||
subfolder=None,
|
||||
user_agent=None,
|
||||
allow_pickle=None,
|
||||
)
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "Using the `_fetch_state_dict()` method from" in warning_message
|
||||
|
||||
def test_best_guess_weight_name_cls_method_raises_warning(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
state_dict = torch.nn.Linear(3, 3).state_dict()
|
||||
torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))
|
||||
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
_ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "Using the `_best_guess_weight_name()` method from" in warning_message
|
||||
@@ -166,14 +166,48 @@ class LoraSD3IntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_sd3_img2img_lora(self):
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
|
||||
pipe.load_lora_weights("zwloong/sd3-lora-training-rank16-v2", weight_name="pytorch_lora_weights.safetensors")
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
|
||||
image = pipe(**inputs).images[0]
|
||||
image_slice = image[0, -3:, -3:]
|
||||
expected_slice = np.array([0.5396, 0.5776, 0.7432, 0.5151, 0.5586, 0.7383, 0.5537, 0.5933, 0.7153])
|
||||
image_slice = image[0, :10, :10]
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.47827148,
|
||||
0.5,
|
||||
0.71972656,
|
||||
0.3955078,
|
||||
0.4194336,
|
||||
0.69628906,
|
||||
0.37036133,
|
||||
0.40820312,
|
||||
0.6923828,
|
||||
0.36450195,
|
||||
0.40429688,
|
||||
0.6904297,
|
||||
0.35595703,
|
||||
0.39257812,
|
||||
0.68652344,
|
||||
0.35498047,
|
||||
0.3984375,
|
||||
0.68310547,
|
||||
0.34716797,
|
||||
0.3996582,
|
||||
0.6855469,
|
||||
0.3388672,
|
||||
0.3959961,
|
||||
0.6816406,
|
||||
0.34033203,
|
||||
0.40429688,
|
||||
0.6845703,
|
||||
0.34228516,
|
||||
0.4086914,
|
||||
0.6870117,
|
||||
]
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
|
||||
|
||||
assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"
|
||||
|
||||
+2
-2
@@ -1787,7 +1787,7 @@ class PeftLoraLoaderMixinTests:
|
||||
logger = (
|
||||
logging.get_logger("diffusers.loaders.unet")
|
||||
if self.unet_kwargs is not None
|
||||
else logging.get_logger("diffusers.loaders.peft")
|
||||
else logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
)
|
||||
logger.setLevel(30)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
@@ -1826,7 +1826,7 @@ class PeftLoraLoaderMixinTests:
|
||||
logger = (
|
||||
logging.get_logger("diffusers.loaders.unet")
|
||||
if self.unet_kwargs is not None
|
||||
else logging.get_logger("diffusers.loaders.peft")
|
||||
else logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
)
|
||||
logger.setLevel(30)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
|
||||
@@ -39,6 +39,7 @@ from diffusers.utils.testing_utils import (
|
||||
load_hf_numpy,
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_fp16,
|
||||
require_torch_accelerator_with_training,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
slow,
|
||||
@@ -169,17 +170,52 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Not tested.")
|
||||
def test_forward_signature(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not tested.")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Decoder", "Encoder"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@require_torch_accelerator_with_training
|
||||
def test_gradient_checkpointing(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
self.assertTrue((loss - loss_2).abs() < 1e-5)
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
for name, param in named_params.items():
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
|
||||
@@ -293,11 +329,9 @@ class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.T
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Not tested.")
|
||||
def test_forward_signature(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not tested.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@@ -330,20 +364,9 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Not tested.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"DecoderTiny", "EncoderTiny"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip(
|
||||
"Gradient checkpointing is supported but this test doesn't apply to this class because it's forward is a bit different from the rest."
|
||||
)
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
|
||||
class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = ConsistencyDecoderVAE
|
||||
@@ -420,17 +443,55 @@ class AutoencoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase)
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Not tested.")
|
||||
def test_forward_signature(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not tested.")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Encoder", "TemporalDecoder"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||
def test_gradient_checkpointing(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
self.assertTrue((loss - loss_2).abs() < 1e-5)
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
for name, param in named_params.items():
|
||||
if "post_quant_conv" in name:
|
||||
continue
|
||||
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
|
||||
|
||||
|
||||
class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
@@ -461,11 +522,9 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Not tested.")
|
||||
def test_forward_signature(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not tested.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
@@ -58,7 +57,6 @@ from diffusers.utils.testing_utils import (
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
run_test_in_subprocess,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -787,101 +785,6 @@ class ModelTesterMixin:
|
||||
model.disable_gradient_checkpointing()
|
||||
self.assertFalse(model.is_gradient_checkpointing)
|
||||
|
||||
@require_torch_accelerator_with_training
|
||||
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
return # Skip test if model does not support gradient checkpointing
|
||||
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
inputs_dict_copy = copy.deepcopy(inputs_dict)
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
torch.manual_seed(0)
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict_copy).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
self.assertTrue((loss - loss_2).abs() < loss_tolerance)
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
|
||||
for name, param in named_params.items():
|
||||
if "post_quant_conv" in name:
|
||||
continue
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
|
||||
def test_gradient_checkpointing_is_applied(
|
||||
self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
|
||||
):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
return # Skip test if model does not support gradient checkpointing
|
||||
if self.model_class.__name__ in [
|
||||
"UNetSpatioTemporalConditionModel",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
]:
|
||||
return
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
if attention_head_dim is not None:
|
||||
init_dict["attention_head_dim"] = attention_head_dim
|
||||
if num_attention_heads is not None:
|
||||
init_dict["num_attention_heads"] = num_attention_heads
|
||||
if block_out_channels is not None:
|
||||
init_dict["block_out_channels"] = block_out_channels
|
||||
|
||||
model_class_copy = copy.copy(self.model_class)
|
||||
|
||||
modules_with_gc_enabled = {}
|
||||
|
||||
# now monkey patch the following function:
|
||||
# def _set_gradient_checkpointing(self, module, value=False):
|
||||
# if hasattr(module, "gradient_checkpointing"):
|
||||
# module.gradient_checkpointing = value
|
||||
|
||||
def _set_gradient_checkpointing_new(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
modules_with_gc_enabled[module.__class__.__name__] = True
|
||||
|
||||
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
|
||||
|
||||
model = model_class_copy(**init_dict)
|
||||
model.enable_gradient_checkpointing()
|
||||
|
||||
print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
|
||||
|
||||
assert set(modules_with_gc_enabled.keys()) == expected_set
|
||||
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
|
||||
|
||||
def test_deprecated_kwargs(self):
|
||||
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
|
||||
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
|
||||
|
||||
@@ -84,13 +84,6 @@ class DiTTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model = Transformer2DModel.from_config(init_dict)
|
||||
assert isinstance(model, DiTTransformer2DModel)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"DiTTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)
|
||||
|
||||
def test_correct_class_remapping_from_pretrained_config(self):
|
||||
config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
|
||||
model = Transformer2DModel.from_config(config)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user