Compare commits

..

52 Commits

Author SHA1 Message Date
yiyixuxu 44fb2fd6ae add 2025-01-14 06:13:45 +01:00
Sayak Paul 74b67524b5 [Docs] Update hunyuan_video.md to rectify the checkpoint id (#10524)
* Update hunyuan_video.md to rectify the checkpoint id

* bfloat16

* more fixes

* don't update the checkpoint ids.

* update

* t -> T

* Apply suggestions from code review

* fix

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-01-13 10:59:13 -10:00
Vinh H. Pham 794f7e49a9 Implement framewise encoding/decoding in LTX Video VAE (#10488)
* add framewise decode

* add framewise encode, refactor tiled encode/decode

* add sanity test tiling for ltx

* run make style

* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

---------

Co-authored-by: Pham Hong Vinh <vinhph3@vng.com.vn>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
2025-01-13 10:58:32 -10:00
Daniel Regado 9fc9c6dd71 Added IP-Adapter for StableDiffusion3ControlNetInpaintingPipeline (#10561)
* Added support for IP-Adapter

* Fixed Copied inconsistency
2025-01-13 10:15:36 -10:00
Omar Awile df355ea2c6 Fix documentation for FluxPipeline (#10563)
Fix argument name in 8bit quantized example

Found a tiny mistake in the documentation where the text encoder model was passed to the wrong argument in the FluxPipeline.from_pretrained function.
2025-01-13 11:56:32 -08:00
Junsong Chen ae019da9e3 [Sana] add Sana to auto-text2image-pipeline; (#10538)
add Sana to auto-text2image-pipeline;
2025-01-13 09:54:37 -10:00
Sayak Paul 329771e542 [LoRA] improve failure handling for peft. (#10551)
* improve failure handling for peft.

* emppty

* Update src/diffusers/loaders/peft.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
2025-01-13 09:20:49 -10:00
Dhruv Nair f7cb595428 [Single File] Fix loading Flux Dev finetunes with Comfy Prefix (#10545)
* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-13 21:25:07 +05:30
hlky c3478a42b9 Fix Nightly AudioLDM2PipelineFastTests (#10556)
* Fix Nightly AudioLDM2PipelineFastTests

* add phonemizer to setup extras test

* fix

* make style
2025-01-13 13:54:06 +00:00
hlky 980736b792 Fix train_dreambooth_lora_sd3_miniature (#10554) 2025-01-13 13:47:27 +00:00
hlky 50c81df4e7 Fix StableDiffusionInstructPix2PixPipelineSingleFileSlowTests (#10557) 2025-01-13 13:47:10 +00:00
Aryan e1c7269720 Fix Latte output_type (#10558)
update
2025-01-13 19:15:59 +05:30
Sayak Paul edb8c1bce6 [Flux] Improve true cfg condition (#10539)
* improve flux true cfg condition

* add test
2025-01-12 18:33:34 +05:30
Sayak Paul 0785dba4df [Docs] Add negative prompt docs to FluxPipeline (#10531)
* add negative_prompt documentation.

* add proper docs for negative prompts

* fix-copies

* remove comment.

* Apply suggestions from code review

Co-authored-by: hlky <hlky@hlky.ac>

* fix-copies

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-01-12 18:02:46 +05:30
Muyang Li 5cda8ea521 Use randn_tensor to replace torch.randn (#10535)
`torch.randn` requires `generator` and `latents` on the same device, while the wrapped function `randn_tensor` does not have this issue.
2025-01-12 11:41:41 +05:30
Sayak Paul 36acdd7517 [Tests] skip tests properly with unittest.skip() (#10527)
* skip tests properly.

* more

* more
2025-01-11 08:46:22 +05:30
Junyu Chen e7db062e10 [DC-AE] support tiling for DC-AE (#10510)
* autoencoder_dc tiling

* add tiling and slicing support in SANA pipelines

* create variables for padding length because the line becomes too long

* add tiling and slicing support in pag SANA pipelines

* revert changes to tile size

* make style

* add vae tiling test

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-01-11 07:15:26 +05:30
andreabosisio 1b0fe63656 Typo fix in the table number of a referenced paper (#10528)
Correcting a typo in the table number of a referenced paper (in scheduling_ddim_inverse.py)

Changed the number of the referenced table from 1 to 2 in a comment of the set_timesteps() method of the DDIMInverseScheduler class (also according to the description of the 'timestep_spacing' attribute of its __init__ method).
2025-01-10 17:15:25 -08:00
chaowenguo d6c030fd37 add the xm.mark_step for the first denosing loop (#10530)
* Update rerender_a_video.py

* Update rerender_a_video.py

* Update examples/community/rerender_a_video.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update rerender_a_video.py

* make style

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-01-10 21:03:41 +00:00
Sayak Paul 9f06a0d1a4 [CI] Match remaining assertions from big runner (#10521)
* print

* remove print.

* print

* update slice.

* empty
2025-01-10 16:37:36 +05:30
Daniel Hipke 52c05bd4cd Add a disable_mmap option to the from_single_file loader to improve load performance on network mounts (#10305)
* Add no_mmap arg.

* Fix arg parsing.

* Update another method to force no mmap.

* logging

* logging2

* propagate no_mmap

* logging3

* propagate no_mmap

* logging4

* fix open call

* clean up logging

* cleanup

* fix missing arg

* update logging and comments

* Rename to disable_mmap and update other references.

* [Docs] Update ltx_video.md to remove generator from `from_pretrained()` (#10316)

Update ltx_video.md to remove generator from `from_pretrained()`

* docs: fix a mistake in docstring (#10319)

Update pipeline_hunyuan_video.py

docs: fix a mistake

* [BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError in function prepare_latents caused by audio_vae_length (#10306)

[BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got float"

torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

in function prepare_latents:
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
...
audio = initial_audio_waveforms.new_zeros(audio_shape)

audio_vae_length evaluates to float because self.transformer.config.sample_size returns a float

Co-authored-by: hlky <hlky@hlky.ac>

* [docs] Fix quantization links (#10323)

Update overview.md

* [Sana]add 2K related model for Sana (#10322)

add 2K related model for Sana

* Update src/diffusers/loaders/single_file_model.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/loaders/single_file.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* make style

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Leojc <liao_junchao@outlook.com>
Co-authored-by: Aditya Raj <syntaxticsugr@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Junsong Chen <cjs1020440147@icloud.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-01-10 15:41:04 +05:30
Sayak Paul a6f043a80f [LoRA] allow big CUDA tests to run properly for LoRA (and others) (#9845)
* allow big lora tests to run on the CI.

* print

* print.

* print

* print

* print

* print

* more

* print

* remove print.

* remove print

* directly place on cuda.

* remove pipeline.

* remove

* fix

* fix

* spaces

* quality

* updates

* directly place flux controlnet pipeline on cuda.

* torch_device instead of cuda.

* style

* device placement.

* fixes

* add big gpu marker for mochi; rename test correctly

* address feedback

* fix

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-01-10 12:50:24 +05:30
hlky 12fbe3f7dc Use Pipelines without unet (#10440)
* Use Pipelines without unet

* unet.config.in_channels

* default_sample_size

* is_unet_version_less_0_9_0

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-10 04:45:42 +00:00
Linoy Tsaban 83ba01a38d small readme changes for advanced training examples (#10473)
add to readme about hf login and wandb installation to address https://github.com/huggingface/diffusers/issues/10142#issuecomment-2571655570

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-10 07:35:19 +05:30
Zehuan Huang 7116fd24e5 Support pass kwargs to cogvideox custom attention processor (#10456)
* Support pass kwargs to cogvideox custom attention processor

* remove args in cogvideox attn processor

* remove unused kwargs
2025-01-09 11:57:03 -10:00
Sayak Paul 553b13845f [LoRA] clean up load_lora_into_text_encoder() and fuse_lora() copied from (#10495)
* factor out text encoder loading.

* make fix-copies

* remove copied from fuse_lora and unfuse_lora as needed.

* remove unused imports
2025-01-09 11:29:16 -10:00
chaowenguo 7bc8b92384 add callable object to convert frame into control_frame to reduce cpu memory usage. (#10501)
* Update rerender_a_video.py

* Update rerender_a_video.py

* Update examples/community/rerender_a_video.py

Co-authored-by: hlky <hlky@hlky.ac>

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-01-09 11:25:53 -10:00
Vladimir Mandic f0c6d9784b flux: make scheduler config params optional (#10384)
* dont assume scheduler has optional config params

* make style, make fix-copies

* calculate_shift

* fix-copies, usage in pipelines

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-01-09 10:44:26 -10:00
Steven Liu d006f0769b [docs] Fix missing parameters in docstrings (#10419)
* fix docstrings

* add
2025-01-09 10:54:39 -08:00
geronimi73 a26d57097a AutoModel instead of AutoModelForCausalLM (#10507) 2025-01-09 16:28:04 +05:30
Sayak Paul daf9d0f119 [chore] remove prints from tests. (#10505)
remove prints from tests.
2025-01-09 14:19:43 +05:30
hlky 95c5ce4e6f PyTorch/XLA support (#10498)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-08 12:31:27 -10:00
Junsong Chen c0964571fc [Sana 4K] (#10493)
add 4K support for Sana
2025-01-08 11:58:11 -10:00
hlky b13cdbb294 UNet2DModel mid_block_type (#10469) 2025-01-08 10:50:29 -10:00
Bagheera a0acbdc989 fix for #7365, prevent pipelines from overriding provided prompt embeds (#7926)
* fix for #7365, prevent pipelines from overriding provided prompt embeds

* fix-copies

* fix implementation

* update

---------

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
2025-01-08 10:12:12 -10:00
Parag Ekbote 5655b22ead Notebooks for Community Scripts-5 (#10499)
Add 5 Notebooks for Diffusers Community
Pipelines.
2025-01-08 08:56:17 -08:00
hlky 4df9d49218 Fix tokenizers install from main in LoRA tests (#10494)
* Fix tokenizers install from main in LoRA tests

* @

* rust

* -e

* uv

* just update tokenizers
2025-01-08 16:14:25 +00:00
Dhruv Nair 9731773d39 [CI] Torch Min Version Test Fix (#10491)
update
2025-01-08 19:43:38 +05:30
Marc Sun e2deb82e69 Fix compatibility with pipeline when loading model with device_map on single gpu (#10390)
* fix device issue in single gpu case

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-08 11:35:00 +01:00
hlky 1288c8560a Update tokenizers in pr_test_peft_backend (#10132)
Update tokenizers
2025-01-08 10:09:32 +00:00
AstraliteHeart cb342b745a Add AuraFlow GGUF support (#10463)
* Add support for loading AuraFlow models from GGUF

https://huggingface.co/city96/AuraFlow-v0.3-gguf

* Update AuraFlow documentation for GGUF, add GGUF tests and model detection.

* Address code review comments.

* Remove unused config.

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-01-08 13:23:12 +05:30
Junsong Chen 80fd9260bb [Sana][bug fix]change clean_caption from True to False. (#10481)
change clean_caption from True to False.
2025-01-07 15:31:23 -10:00
Aryan 71ad16b463 Add _no_split_modules to some models (#10308)
* set supports gradient checkpointing to true where necessary; add missing no split modules

* fix cogvideox tests

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-01-08 06:34:19 +05:30
hlky ee7e141d80 Use pipelines without vae (#10441)
* Use pipelines without vae

* getattr

* vqvae

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-07 13:26:51 -10:00
hlky 01bd79649e Fix HunyuanVideo produces NaN on PyTorch<2.5 (#10482)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-07 13:13:55 -10:00
Teriks 03bcf5aefe RFInversionFluxPipeline, small fix for enable_model_cpu_offload & enable_sequential_cpu_offload compatibility (#10480)
RFInversionFluxPipeline.encode_image, device fix

Use self._execution_device instead of self.device when selecting
a device for the input image tensor.

This allows for compatibility with enable_model_cpu_offload &
enable_sequential_cpu_offload

Co-authored-by: Teriks <Teriks@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-01-07 15:47:28 +01:00
dependabot[bot] e0b96ba7b0 Bump jinja2 from 3.1.4 to 3.1.5 in /examples/research_projects/realfill (#10377)
Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.4 to 3.1.5.
- [Release notes](https://github.com/pallets/jinja/releases)
- [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst)
- [Commits](https://github.com/pallets/jinja/compare/3.1.4...3.1.5)

---
updated-dependencies:
- dependency-name: jinja2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-01-07 19:59:41 +05:30
Dhruv Nair 854a04659c [CI] Add minimal testing for legacy Torch versions (#10479)
* update

* update
2025-01-07 18:51:41 +05:30
hlky 628f2c544a Use Pipelines without scheduler (#10439)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-07 12:07:08 +00:00
Aryan 811560b1d7 [LoRA] Support original format loras for HunyuanVideo (#10376)
* update

* fix make copies

* update

* add relevant markers to the integration test suite.

* add copied.

* fox-copies

* temporarily add print.

* directly place on CUDA as CPU isn't that big on the CIO.

* fixes to fuse_lora, aryan was right.

* fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-07 13:18:57 +05:30
Rahul Raman f1e0c7ce4a Refactor instructpix2pix lora to support peft (#10205)
* make base code changes referred from train_instructpix2pix script in examples

* change code to use PEFT as discussed in issue 10062

* update README training command

* update README training command

* refactor variable name and freezing unet

* Update examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update README installation instructions.

* cleanup code using make style and quality

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-07 12:00:45 +05:30
Sayak Paul b94cfd7937 [Training] QoL improvements in the Flux Control training scripts (#10461)
* qol improvements to the Flux script.

* propagate the dataloader changes.
2025-01-07 11:56:17 +05:30
322 changed files with 4633 additions and 1750 deletions
+1 -1
View File
@@ -272,7 +272,7 @@ jobs:
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_version_cuda \
tests/models/test_modelling_common.py \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
tests/pipelines/test_pipeline_utils.py \
tests/pipelines/test_pipelines.py \
+1
View File
@@ -266,6 +266,7 @@ jobs:
# TODO (sayakpaul, DN6): revisit `--no-deps`
python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
python -m uv pip install -U tokenizers
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
+1 -1
View File
@@ -193,7 +193,7 @@ jobs:
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_cuda \
tests/models/test_modelling_common.py \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
tests/pipelines/test_pipeline_utils.py \
tests/pipelines/test_pipelines.py \
+27
View File
@@ -62,6 +62,33 @@ image = pipeline(prompt).images[0]
image.save("auraflow.png")
```
Loading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) are also supported:
```py
import torch
from diffusers import (
AuraFlowPipeline,
GGUFQuantizationConfig,
AuraFlowTransformer2DModel,
)
transformer = AuraFlowTransformer2DModel.from_single_file(
"https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
"fal/AuraFlow-v0.3",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
prompt = "a cute pony in a field of flowers"
image = pipeline(prompt).images[0]
image.save("auraflow.png")
```
## AuraFlowPipeline
[[autodoc]] AuraFlowPipeline
+1 -1
View File
@@ -367,7 +367,7 @@ transformer_8bit = FluxTransformer2DModel.from_pretrained(
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=text_encoder_8bit,
text_encoder_2=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
@@ -16,7 +16,7 @@
[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/tencent/HunyuanVideo).*
<Tip>
@@ -45,14 +45,14 @@ from diffusers.utils import export_to_video
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
"tencent/HunyuanVideo",
"hunyuanvideo-community/HunyuanVideo",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.float16,
torch_dtype=torch.bfloat16,
)
pipeline = HunyuanVideoPipeline.from_pretrained(
"tencent/HunyuanVideo",
"hunyuanvideo-community/HunyuanVideo",
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
+2 -2
View File
@@ -59,10 +59,10 @@ Refer to the [Quantization](../../quantization/overview) overview to learn more
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModelForCausalLM
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModelForCausalLM.from_pretrained(
text_encoder_8bit = AutoModel.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_diffusers",
subfolder="text_encoder",
quantization_config=quant_config,
@@ -78,10 +78,10 @@ from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
"tencent/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
"hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
"tencent/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
"hunyuanvideo-community/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
)
# reduce memory requirements
@@ -67,6 +67,17 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
```bash
huggingface-cli login
```
This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.
> [!NOTE]
> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
### Pivotal Tuning
**Training with text encoder(s)**
@@ -65,6 +65,17 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
```bash
huggingface-cli login
```
This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.
> [!NOTE]
> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
+5 -5
View File
@@ -33,12 +33,12 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_comparison.ipynb) | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) |
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_unclip.ipynb) | [Ray Wang](https://wrong.wang) |
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_text_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) |
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
@@ -50,7 +50,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
+12 -8
View File
@@ -372,7 +372,7 @@ class AdaptiveMaskInpaintPipeline(
self.register_adaptive_mask_model()
self.register_adaptive_mask_settings()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -386,7 +386,7 @@ class AdaptiveMaskInpaintPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -416,10 +416,14 @@ class AdaptiveMaskInpaintPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -438,7 +442,7 @@ class AdaptiveMaskInpaintPipeline(
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet.config.in_channels != 9:
if unet is not None and unet.config.in_channels != 9:
logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
self.register_modules(
@@ -450,7 +454,7 @@ class AdaptiveMaskInpaintPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -89,7 +89,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -103,7 +103,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -132,10 +132,14 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -162,7 +166,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+1 -1
View File
@@ -35,7 +35,7 @@ class EDICTPipeline(DiffusionPipeline):
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
+1 -1
View File
@@ -1342,7 +1342,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+1 -1
View File
@@ -221,7 +221,7 @@ class GlueGenStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin, St
language_adapter=language_adapter,
tensor_norm=tensor_norm,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
+1 -1
View File
@@ -95,7 +95,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+11 -7
View File
@@ -109,7 +109,7 @@ class InstaFlowPipeline(
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -123,7 +123,7 @@ class InstaFlowPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -152,10 +152,14 @@ class InstaFlowPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -182,7 +186,7 @@ class InstaFlowPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -86,7 +86,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline, StableDiffusionMixin):
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+11 -7
View File
@@ -191,7 +191,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -205,7 +205,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -234,10 +234,14 @@ class IPAdapterFaceIDStableDiffusionPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -265,7 +269,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
+1 -1
View File
@@ -463,6 +463,6 @@ class StableDiffusionHighResFixPipeline(StableDiffusionPipeline):
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -69,7 +69,7 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
@@ -273,7 +273,7 @@ class LatentConsistencyModelWalkPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -67,7 +67,7 @@ class LatentConsistencyModelPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
+11 -7
View File
@@ -336,7 +336,7 @@ class LLMGroundedDiffusionPipeline(
# This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -350,7 +350,7 @@ class LLMGroundedDiffusionPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -379,10 +379,14 @@ class LLMGroundedDiffusionPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -410,7 +414,7 @@ class LLMGroundedDiffusionPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
+11 -7
View File
@@ -496,7 +496,7 @@ class StableDiffusionLongPromptWeightingPipeline(
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -510,7 +510,7 @@ class StableDiffusionLongPromptWeightingPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -539,10 +539,14 @@ class StableDiffusionLongPromptWeightingPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -568,7 +572,7 @@ class StableDiffusionLongPromptWeightingPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
+11 -4
View File
@@ -673,12 +673,16 @@ class SDXLLongPromptWeightingPipeline(
image_encoder=image_encoder,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -827,7 +831,9 @@ class SDXLLongPromptWeightingPipeline(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -879,7 +885,8 @@ class SDXLLongPromptWeightingPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
+10 -6
View File
@@ -3766,7 +3766,7 @@ class MatryoshkaPipeline(
else:
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -3780,7 +3780,7 @@ class MatryoshkaPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
# if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
# deprecation_message = (
# f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
# " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -3793,10 +3793,14 @@ class MatryoshkaPipeline(
# new_config["clip_sample"] = False
# scheduler._internal_dict = FrozenDict(new_config)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -98,7 +98,7 @@ class MultilingualStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -188,7 +188,7 @@ class AnimateDiffControlNetPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -308,7 +308,7 @@ class AnimateDiffImgToVideoPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
@@ -162,7 +162,7 @@ class AnimateDiffPipelineIpex(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+11 -4
View File
@@ -166,9 +166,13 @@ class DemoFusionSDXLPipeline(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -290,7 +294,9 @@ class DemoFusionSDXLPipeline(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -342,7 +348,8 @@ class DemoFusionSDXLPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
+9 -5
View File
@@ -150,10 +150,14 @@ class FabricPipeline(DiffusionPipeline):
):
super().__init__()
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -179,7 +183,7 @@ class FabricPipeline(DiffusionPipeline):
tokenizer=tokenizer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -221,13 +221,12 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
vae_latent_channels=self.vae.config.latent_channels,
vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=False,
do_convert_grayscale=True,
@@ -876,10 +875,10 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
@@ -219,9 +219,7 @@ class RFInversionFluxPipeline(
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -419,7 +417,7 @@ class RFInversionFluxPipeline(
)
image = image.to(dtype)
x0 = self.vae.encode(image.to(self.device)).latent_dist.sample()
x0 = self.vae.encode(image.to(self._execution_device)).latent_dist.sample()
x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor
x0 = x0.to(dtype)
return x0, resized
@@ -822,10 +820,10 @@ class RFInversionFluxPipeline(
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
@@ -992,10 +990,10 @@ class RFInversionFluxPipeline(
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inversion_steps = retrieve_timesteps(
self.scheduler,
+6 -7
View File
@@ -64,6 +64,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
@@ -189,9 +190,7 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -757,10 +756,10 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
@@ -327,9 +327,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
@@ -209,16 +209,18 @@ class KolorsDifferentialImg2ImgPipeline(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True
)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
# Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt
def encode_prompt(
+11 -7
View File
@@ -131,7 +131,7 @@ class Prompt2PromptPipeline(
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -145,7 +145,7 @@ class Prompt2PromptPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -174,10 +174,14 @@ class Prompt2PromptPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -205,7 +209,7 @@ class Prompt2PromptPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -488,13 +488,17 @@ class StyleAlignedSDXLPipeline(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -628,7 +632,9 @@ class StyleAlignedSDXLPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -688,7 +694,8 @@ class StyleAlignedSDXLPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -207,7 +207,7 @@ class StableDiffusion3DifferentialImg2ImgPipeline(DiffusionPipeline):
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
)
@@ -417,7 +417,7 @@ class StableDiffusionBoxDiffPipeline(
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -431,7 +431,7 @@ class StableDiffusionBoxDiffPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -460,10 +460,14 @@ class StableDiffusionBoxDiffPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -491,7 +495,7 @@ class StableDiffusionBoxDiffPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -384,7 +384,7 @@ class StableDiffusionPAGPipeline(
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -398,7 +398,7 @@ class StableDiffusionPAGPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -427,10 +427,14 @@ class StableDiffusionPAGPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -458,7 +462,7 @@ class StableDiffusionPAGPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -151,7 +151,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
watermarker=watermarker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor, resample="bilinear")
# self.register_to_config(requires_safety_checker=requires_safety_checker)
self.register_to_config(max_noise_level=max_noise_level)
@@ -226,12 +226,16 @@ class StableDiffusionXLControlNetAdapterPipeline(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
@@ -359,7 +363,9 @@ class StableDiffusionXLControlNetAdapterPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -419,7 +425,8 @@ class StableDiffusionXLControlNetAdapterPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -374,12 +374,16 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
@@ -507,7 +511,9 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -567,7 +573,8 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -258,7 +258,7 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -394,7 +394,9 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -454,7 +456,8 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -253,10 +253,14 @@ class StableDiffusionXLPipelineIpex(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -390,7 +394,9 @@ class StableDiffusionXLPipelineIpex(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -450,7 +456,8 @@ class StableDiffusionXLPipelineIpex(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
+11 -7
View File
@@ -108,7 +108,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -122,7 +122,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -151,10 +151,14 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -181,7 +185,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
feature_extractor=feature_extractor,
cc_projection=cc_projection,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
# self.model_mode = None
+7 -4
View File
@@ -352,7 +352,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -632,7 +632,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process.
control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation.
control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame.
strength ('float'): SDEdit strength.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -789,7 +789,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
# Currently we only support single control
if isinstance(controlnet, ControlNetModel):
control_image = self.prepare_control_image(
image=control_frames[0],
image=control_frames(frames[0]) if callable(control_frames) else control_frames[0],
width=width,
height=height,
batch_size=batch_size,
@@ -908,6 +908,9 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
@@ -924,7 +927,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
for idx in range(1, len(frames)):
image = frames[idx]
prev_image = frames[idx - 1]
control_image = control_frames[idx]
control_image = control_frames(image) if callable(control_frames) else control_frames[idx]
# 5.1 prepare frames
image = self.image_processor.preprocess(image).to(dtype=self.dtype)
prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)
@@ -179,7 +179,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, StableDiffusio
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
@@ -278,7 +278,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, StableDiffusio
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
@@ -263,7 +263,7 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline, StableD
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
+11 -7
View File
@@ -105,7 +105,7 @@ class StableDiffusionIPEXPipeline(
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -119,7 +119,7 @@ class StableDiffusionIPEXPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -148,10 +148,14 @@ class StableDiffusionIPEXPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -178,7 +182,7 @@ class StableDiffusionIPEXPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1):
+1 -1
View File
@@ -66,7 +66,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline, StableDiffusionMixin):
requires_safety_checker: bool = True,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -132,7 +132,7 @@ class StableDiffusionReferencePipeline(
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -146,7 +146,7 @@ class StableDiffusionReferencePipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -181,10 +181,14 @@ class StableDiffusionReferencePipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -202,7 +206,7 @@ class StableDiffusionReferencePipeline(
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet.config.in_channels != 4:
if unet is not None and unet.config.in_channels != 4:
logger.warning(
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
@@ -219,7 +223,7 @@ class StableDiffusionReferencePipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
+12 -8
View File
@@ -187,7 +187,7 @@ class StableDiffusionRepaintPipeline(
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -201,7 +201,7 @@ class StableDiffusionRepaintPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -236,10 +236,14 @@ class StableDiffusionRepaintPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -257,7 +261,7 @@ class StableDiffusionRepaintPipeline(
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet.config.in_channels != 4:
if unet is not None and unet.config.in_channels != 4:
logger.warning(
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
@@ -274,7 +278,7 @@ class StableDiffusionRepaintPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -710,7 +710,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -724,7 +724,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -753,10 +753,14 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -806,7 +810,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -714,7 +714,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -728,7 +728,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -757,10 +757,14 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -810,7 +814,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -626,7 +626,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -640,7 +640,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -669,10 +669,14 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -722,7 +726,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
+2 -2
View File
@@ -71,7 +71,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -85,7 +85,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -120,7 +120,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+3 -3
View File
@@ -121,7 +121,7 @@ prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
prompt=prompt,
condition_image=image,
control_image=image,
num_inference_steps=50,
joint_attention_kwargs={"scale": 0.9},
guidance_scale=25.,
@@ -190,7 +190,7 @@ prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
prompt=prompt,
condition_image=image,
control_image=image,
num_inference_steps=50,
guidance_scale=25.,
).images[0]
@@ -200,5 +200,5 @@ gen_images.save("output.png")
## Things to note
* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
+51 -9
View File
@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
for _ in range(args.num_validation_images):
with autocast_ctx:
# need to fix in pipeline_flux_controlnet
image = pipeline(
prompt=validation_prompt,
control_image=validation_image,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f"""
# control-lora-{repo_id}
# flux-control-{repo_id}
These are Control weights trained on {base_model} with new type of conditioning.
{img_str}
@@ -434,7 +433,7 @@ def parse_args(input_args=None):
"--conditioning_image_column",
type=str,
default="conditioning_image",
help="The column of the dataset containing the controlnet conditioning image.",
help="The column of the dataset containing the control conditioning image.",
)
parser.add_argument(
"--caption_column",
@@ -442,6 +441,7 @@ def parse_args(input_args=None):
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
parser.add_argument(
"--max_train_samples",
type=int,
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
default=None,
nargs="+",
help=(
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
"A set of paths to the control conditioning image be evaluated every `--validation_steps`"
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
" `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
default=None,
help="Path to the jsonl file containing the training data.",
)
parser.add_argument(
"--only_target_transformer_blocks",
action="store_true",
help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).",
)
parser.add_argument(
"--guidance_scale",
type=float,
@@ -581,7 +585,7 @@ def parse_args(input_args=None):
if args.resolution % 8 != 0:
raise ValueError(
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
)
return args
@@ -665,7 +669,12 @@ def prepare_train_dataset(dataset, accelerator):
conditioning_images = [image_transforms(image) for image in conditioning_images]
examples["pixel_values"] = images
examples["conditioning_pixel_values"] = conditioning_images
examples["captions"] = list(examples[args.caption_column])
is_caption_list = isinstance(examples[args.caption_column][0], list)
if is_caption_list:
examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
else:
examples["captions"] = list(examples[args.caption_column])
return examples
@@ -765,7 +774,8 @@ def main(args):
subfolder="scheduler",
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
flux_transformer.requires_grad_(True)
if not args.only_target_transformer_blocks:
flux_transformer.requires_grad_(True)
vae.requires_grad_(False)
# cast down and move to the CPU
@@ -797,6 +807,12 @@ def main(args):
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
if args.only_target_transformer_blocks:
flux_transformer.x_embedder.requires_grad_(True)
for name, module in flux_transformer.named_modules():
if "transformer_blocks" in name:
module.requires_grad_(True)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
@@ -974,6 +990,32 @@ def main(args):
else:
initial_global_step = 0
if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
logger.info("Logging some dataset samples.")
formatted_images = []
formatted_control_images = []
all_prompts = []
for i, batch in enumerate(train_dataloader):
images = (batch["pixel_values"] + 1) / 2
control_images = (batch["conditioning_pixel_values"] + 1) / 2
prompts = batch["captions"]
if len(formatted_images) > 10:
break
for img, control_img, prompt in zip(images, control_images, prompts):
formatted_images.append(img)
formatted_control_images.append(control_img)
all_prompts.append(prompt)
logged_artifacts = []
for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
logged_artifacts.append(wandb.Image(img, caption=prompt))
wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
wandb_tracker[0].log({"dataset_samples": logged_artifacts})
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
@@ -132,7 +132,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
for _ in range(args.num_validation_images):
with autocast_ctx:
# need to fix in pipeline_flux_controlnet
image = pipeline(
prompt=validation_prompt,
control_image=validation_image,
@@ -169,7 +168,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
@@ -198,7 +197,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f"""
# controlnet-lora-{repo_id}
# control-lora-{repo_id}
These are Control LoRA weights trained on {base_model} with new type of conditioning.
{img_str}
@@ -256,7 +255,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
default="controlnet-lora",
default="control-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
@@ -466,7 +465,7 @@ def parse_args(input_args=None):
"--conditioning_image_column",
type=str,
default="conditioning_image",
help="The column of the dataset containing the controlnet conditioning image.",
help="The column of the dataset containing the control conditioning image.",
)
parser.add_argument(
"--caption_column",
@@ -474,6 +473,7 @@ def parse_args(input_args=None):
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
parser.add_argument(
"--max_train_samples",
type=int,
@@ -500,7 +500,7 @@ def parse_args(input_args=None):
default=None,
nargs="+",
help=(
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
"A set of paths to the control conditioning image be evaluated every `--validation_steps`"
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
" `--validation_image` that will be used with all `--validation_prompt`s."
@@ -613,7 +613,7 @@ def parse_args(input_args=None):
if args.resolution % 8 != 0:
raise ValueError(
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
)
return args
@@ -697,7 +697,12 @@ def prepare_train_dataset(dataset, accelerator):
conditioning_images = [image_transforms(image) for image in conditioning_images]
examples["pixel_values"] = images
examples["conditioning_pixel_values"] = conditioning_images
examples["captions"] = list(examples[args.caption_column])
is_caption_list = isinstance(examples[args.caption_column][0], list)
if is_caption_list:
examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
else:
examples["captions"] = list(examples[args.caption_column])
return examples
@@ -1132,6 +1137,32 @@ def main(args):
else:
initial_global_step = 0
if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
logger.info("Logging some dataset samples.")
formatted_images = []
formatted_control_images = []
all_prompts = []
for i, batch in enumerate(train_dataloader):
images = (batch["pixel_values"] + 1) / 2
control_images = (batch["conditioning_pixel_values"] + 1) / 2
prompts = batch["captions"]
if len(formatted_images) > 10:
break
for img, control_img, prompt in zip(images, control_images, prompts):
formatted_images.append(img)
formatted_control_images.append(control_img)
all_prompts.append(prompt)
logged_artifacts = []
for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
logged_artifacts.append(wandb.Image(img, caption=prompt))
wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
wandb_tracker[0].log({"dataset_samples": logged_artifacts})
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
@@ -2,6 +2,34 @@
This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script provides further support add LoRA layers for unet model.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd in the example folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
## Training script example
```bash
@@ -9,7 +37,7 @@ export MODEL_ID="timbrooks/instruct-pix2pix"
export DATASET_ID="instruction-tuning-sd/cartoonization"
export OUTPUT_DIR="instructPix2Pix-cartoonization"
accelerate launch finetune_instruct_pix2pix.py \
accelerate launch train_instruct_pix2pix_lora.py \
--pretrained_model_name_or_path=$MODEL_ID \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
@@ -24,7 +52,10 @@ accelerate launch finetune_instruct_pix2pix.py \
--rank=4 \
--output_dir=$OUTPUT_DIR \
--report_to=wandb \
--push_to_hub
--push_to_hub \
--original_image_column="original_image" \
--edited_image_column="cartoonized_image" \
--edit_prompt_column="edit_prompt"
```
## Inference
@@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
"""
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""
import argparse
import logging
@@ -30,6 +33,7 @@ import numpy as np
import PIL
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
@@ -39,21 +43,28 @@ from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.training_utils import EMAModel, cast_training_params
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0")
check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -63,6 +74,92 @@ DATASET_NAME_MAPPING = {
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
def save_model_card(
repo_id: str,
images: list = None,
base_model: str = None,
dataset_name: str = None,
repo_folder: str = None,
):
img_str = ""
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
model_description = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)
tags = [
"stable-diffusion",
"stable-diffusion-diffusers",
"text-to-image",
"instruct-pix2pix",
"diffusers",
"diffusers-training",
"lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(
pipeline,
args,
accelerator,
generator,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
original_image = download_image(args.val_image_url)
edited_images = []
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
tracker.log({"validation": wandb_table})
return edited_images
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
parser.add_argument(
@@ -417,11 +514,6 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -467,49 +559,58 @@ def main():
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
# InstructPix2Pix uses an additional image for conditioning. To accommodate that,
# it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
# then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
# from the pre-trained checkpoints. For the extra channels added to the first layer, they are
# initialized to zero.
logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
in_channels = 8
out_channels = unet.conv_in.out_channels
unet.register_to_config(in_channels=in_channels)
with torch.no_grad():
new_conv_in = nn.Conv2d(
in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
)
new_conv_in.weight.zero_()
new_conv_in.weight[:, :in_channels, :, :].copy_(unet.conv_in.weight)
unet.conv_in = new_conv_in
# Freeze vae, text_encoder and unet
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
# referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
# Freeze the unet parameters before adding adapters
unet.requires_grad_(False)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# Add adapter and make sure the trainable params are in float32.
unet.add_adapter(unet_lora_config)
if args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(unet, dtype=torch.float32)
# Create EMA for the unet.
if args.use_ema:
@@ -528,6 +629,13 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
trainable_params = filter(lambda p: p.requires_grad, unet.parameters())
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -540,7 +648,8 @@ def main():
model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()
def load_model_hook(models, input_dir):
if args.use_ema:
@@ -589,9 +698,9 @@ def main():
else:
optimizer_cls = torch.optim.AdamW
# train on only unet_lora_parameters
# train on only lora_layers
optimizer = optimizer_cls(
unet_lora_parameters,
trainable_params,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -730,22 +839,27 @@ def main():
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
)
# Prepare everything with our `accelerator`.
unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
if args.use_ema:
@@ -765,8 +879,14 @@ def main():
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -885,7 +1005,7 @@ def main():
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
@@ -895,7 +1015,7 @@ def main():
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm)
accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -903,7 +1023,7 @@ def main():
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
ema_unet.step(unet_lora_parameters)
ema_unet.step(trainable_params)
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -933,6 +1053,16 @@ def main():
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)
StableDiffusionInstructPix2PixPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -959,45 +1089,22 @@ def main():
# The models need unwrapping because for compatibility in distributed training mode.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
unet=unwrap_model(unet),
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
original_image = download_image(args.val_image_url)
edited_images = []
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
log_validation(
pipeline,
args,
accelerator,
generator,
)
with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"validation": wandb_table})
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
@@ -1008,22 +1115,47 @@ def main():
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())
# store only LORA layers
unet = unet.to(torch.float32)
unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
StableDiffusionInstructPix2PixPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
unet=unet,
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
)
# store only LORA layers
unet.save_attn_procs(args.output_dir)
pipeline.load_lora_weights(args.output_dir)
images = None
if (args.val_image_url is not None) and (args.validation_prompt is not None):
images = log_validation(
pipeline,
args,
accelerator,
generator,
)
if args.push_to_hub:
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
dataset_name=args.dataset_name,
repo_folder=args.output_dir,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
@@ -1031,31 +1163,6 @@ def main():
ignore_patterns=["step_*", "epoch_*"],
)
if args.validation_prompt is not None:
edited_images = []
pipeline = pipeline.to(accelerator.device)
with torch.autocast(str(accelerator.device).replace(":0", "")):
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"test": wandb_table})
accelerator.end_training()
@@ -310,7 +310,7 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline):
controlnet=controlnet,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -233,7 +233,7 @@ class PromptDiffusionPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -78,7 +78,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
feature_extractor=feature_extractor,
)
# Copy from statement here and all the methods we take from stable_diffusion_pipeline
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.retriever = retriever
@@ -6,4 +6,4 @@ torch==2.2.0
torchvision>=0.16
ftfy==6.1.1
tensorboard==2.14.0
Jinja2==3.1.4
Jinja2==3.1.5
@@ -765,7 +765,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+8 -4
View File
@@ -25,6 +25,7 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
@@ -89,7 +90,10 @@ def main(args):
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# scheduler
flow_shift = 3.0
if args.image_size == 4096:
flow_shift = 6.0
else:
flow_shift = 3.0
# model config
if args.model_type == "SanaMS_1600M_P1_D20":
@@ -99,7 +103,7 @@ def main(args):
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
for depth in range(layer_num):
# Transformer blocks.
@@ -272,9 +276,9 @@ if __name__ == "__main__":
"--image_size",
default=1024,
type=int,
choices=[512, 1024, 2048],
choices=[512, 1024, 2048, 4096],
required=False,
help="Image size of pretrained model, 512, 1024 or 2048.",
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
+2
View File
@@ -135,6 +135,7 @@ _deps = [
"transformers>=4.41.2",
"urllib3<=2.0.0",
"black",
"phonemizer",
]
# this is a lookup table with items like:
@@ -227,6 +228,7 @@ extras["test"] = deps_list(
"scipy",
"torchvision",
"transformers",
"phonemizer",
)
extras["torch"] = deps_list("torch", "accelerate")
@@ -43,4 +43,5 @@ deps = {
"transformers": "transformers>=4.41.2",
"urllib3": "urllib3<=2.0.0",
"black": "black",
"phonemizer": "phonemizer",
}
+156 -21
View File
@@ -28,13 +28,20 @@ from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
delete_adapter_layers,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
is_transformers_available,
is_transformers_version,
logging,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
@@ -43,6 +50,8 @@ from ..utils import (
if is_transformers_available():
from transformers import PreTrainedModel
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -297,6 +306,152 @@ def _best_guess_weight_name(
return weight_name
def _load_lora_into_text_encoder(
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
text_encoder_name="text_encoder",
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def _func_optionally_disable_offloading(_pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
class LoraBaseMixin:
"""Utility class for handling LoRAs."""
@@ -327,27 +482,7 @@ class LoraBaseMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
return _func_optionally_disable_offloading(_pipeline=_pipeline)
@classmethod
def _fetch_state_dict(cls, *args, **kwargs):
@@ -973,3 +973,178 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
def remap_norm_scale_shift_(key, state_dict):
weight = state_dict.pop(key)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
def remap_txt_in_(key, state_dict):
def rename_key(key):
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
new_key = new_key.replace("txt_in", "context_embedder")
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
new_key = new_key.replace("mlp", "ff")
return new_key
if "self_attn_qkv" in key:
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
else:
state_dict[rename_key(key)] = state_dict.pop(key)
def remap_img_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
if "lora_A" in key:
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
else:
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
def remap_txt_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
if "lora_A" in key:
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
else:
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
def remap_single_transformer_blocks_(key, state_dict):
hidden_size = 3072
if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
linear1_weight = state_dict.pop(key)
if "lora_A" in key:
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_A.weight"
)
state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
else:
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_B.weight"
)
state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
linear1_bias = state_dict.pop(key)
if "lora_A" in key:
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_A.bias"
)
state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
else:
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_B.bias"
)
state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
else:
new_key = key.replace("single_blocks", "single_transformer_blocks")
new_key = new_key.replace("linear2", "proj_out")
new_key = new_key.replace("q_norm", "attn.norm_q")
new_key = new_key.replace("k_norm", "attn.norm_k")
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT = {
"img_in": "x_embedder",
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_norm1": "norm1.norm",
"img_norm2": "norm2",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_norm1": "norm1.norm",
"txt_norm2": "norm2_context",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"pre_norm": "norm.norm",
"final_layer.norm_final": "norm_out.norm",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}
# Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
# and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
# sure that both follow the same initial format by stripping off the "transformer." prefix.
for key in list(converted_state_dict.keys()):
if key.startswith("transformer."):
converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
if key.startswith("diffusion_model."):
converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
# Rename and remap the state dict keys
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
# Add back the "transformer." prefix
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
+79 -612
View File
@@ -20,22 +20,24 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..utils import (
USE_PEFT_BACKEND,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_peft_available,
is_peft_version,
is_torch_version,
is_transformers_available,
is_transformers_version,
logging,
scale_lora_layers,
)
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa
from .lora_base import ( # noqa
LORA_WEIGHT_NAME,
LORA_WEIGHT_NAME_SAFE,
LoraBaseMixin,
_fetch_state_dict,
_load_lora_into_text_encoder,
)
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
@@ -54,9 +56,6 @@ if is_torch_version(">=", "1.9.0"):
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
if is_transformers_available():
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
@@ -348,119 +347,17 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
def save_lora_weights(
@@ -891,119 +788,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
def save_lora_weights(
@@ -1400,119 +1195,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
def save_lora_weights(
@@ -2032,119 +1725,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
@@ -2203,7 +1794,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -2597,119 +2188,17 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
def save_lora_weights(
@@ -3007,10 +2496,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -3051,8 +2539,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3066,9 +2553,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -3315,10 +2799,9 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -3359,8 +2842,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3374,9 +2856,6 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -3623,10 +3102,9 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -3667,8 +3145,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3682,9 +3159,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -3931,10 +3405,9 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -3975,8 +3448,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3990,9 +3462,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -4007,7 +3476,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4018,7 +3486,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
We support loading original format HunyuanVideo LoRA checkpoints.
This function is experimental and might change in the future.
@@ -4101,6 +3569,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
if is_original_hunyuan_video:
state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
@@ -4239,10 +3711,9 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -4283,8 +3754,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -4298,9 +3768,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
+12 -35
View File
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Union
import safetensors
import torch
import torch.nn as nn
from ..utils import (
MIN_PEFT_VERSION,
@@ -30,20 +29,16 @@ from ..utils import (
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .lora_base import _fetch_state_dict
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .unet_loader_utils import _maybe_expand_lora_scales
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
_SET_ADAPTER_SCALE_FN_MAPPING = {
@@ -140,27 +135,7 @@ class PeftAdapterMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
return _func_optionally_disable_offloading(_pipeline=_pipeline)
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
r"""
@@ -325,15 +300,17 @@ class PeftAdapterMixin:
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
except RuntimeError as e:
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)
self.peft_config.pop(adapter_name)
self.peft_config.pop(adapter_name)
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
raise
+8
View File
@@ -60,6 +60,7 @@ def load_single_file_sub_model(
local_files_only=False,
torch_dtype=None,
is_legacy_loading=False,
disable_mmap=False,
**kwargs,
):
if is_pipeline_module:
@@ -106,6 +107,7 @@ def load_single_file_sub_model(
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
disable_mmap=disable_mmap,
**kwargs,
)
@@ -308,6 +310,9 @@ class FromSingleFileMixin:
hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -355,6 +360,7 @@ class FromSingleFileMixin:
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
disable_mmap = kwargs.pop("disable_mmap", False)
is_legacy_loading = False
@@ -383,6 +389,7 @@ class FromSingleFileMixin:
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)
if config is None:
@@ -504,6 +511,7 @@ class FromSingleFileMixin:
original_config=original_config,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
disable_mmap=disable_mmap,
**kwargs,
)
except SingleFileComponentError as e:
@@ -25,6 +25,7 @@ from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
@@ -106,6 +107,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AuraFlowTransformer2DModel": {
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}
@@ -182,6 +187,9 @@ class FromOriginalModelMixin:
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
@@ -229,6 +237,7 @@ class FromOriginalModelMixin:
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
@@ -241,6 +250,7 @@ class FromOriginalModelMixin:
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
+113 -4
View File
@@ -94,6 +94,12 @@ CHECKPOINT_KEY_NAMES = {
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight",
"auraflow": [
"double_layers.0.attn.w2q.weight",
"double_layers.0.attn.w1q.weight",
"cond_seq_linear.weight",
"t_embedder.mlp.0.weight",
],
"flux": [
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -154,6 +160,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -179,6 +186,7 @@ DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
"inpainting": 512,
"inpainting_v2": 512,
"controlnet": 512,
"instruct-pix2pix": 512,
"v2": 768,
"v1": 512,
}
@@ -380,6 +388,7 @@ def load_single_file_checkpoint(
cache_dir=None,
local_files_only=None,
revision=None,
disable_mmap=False,
):
if os.path.isfile(pretrained_model_link_or_path):
pretrained_model_link_or_path = pretrained_model_link_or_path
@@ -397,7 +406,7 @@ def load_single_file_checkpoint(
revision=revision,
)
checkpoint = load_state_dict(pretrained_model_link_or_path)
checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
# some checkpoints contain the model state dict under a "state_dict" key
while "state_dict" in checkpoint:
@@ -597,10 +606,14 @@ def infer_diffusers_model_type(checkpoint):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
if checkpoint["img_in.weight"].shape[1] == 384:
model_type = "flux-fill"
if "model.diffusion_model.img_in.weight" in checkpoint:
key = "model.diffusion_model.img_in.weight"
else:
key = "img_in.weight"
elif checkpoint["img_in.weight"].shape[1] == 128:
if checkpoint[key].shape[1] == 384:
model_type = "flux-fill"
elif checkpoint[key].shape[1] == 128:
model_type = "flux-depth"
else:
model_type = "flux-dev"
@@ -635,6 +648,9 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
model_type = "hunyuan-video"
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
model_type = "auraflow"
elif (
CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
@@ -2090,6 +2106,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2689,3 +2706,95 @@ def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, checkpoint)
return checkpoint
def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
state_dict_keys = list(checkpoint.keys())
# Handle register tokens and positional embeddings
converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
# Handle time step projection
converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
# Handle context embedder
converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
# Calculate the number of layers
def calculate_layers(keys, key_prefix):
layers = set()
for k in keys:
if key_prefix in k:
layer_num = int(k.split(".")[1]) # get the layer number
layers.add(layer_num)
return len(layers)
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
# MMDiT blocks
for i in range(mmdit_layers):
# Feed-forward
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for orig_k, diffuser_k in path_mapping.items():
for k, v in weight_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
f"double_layers.{i}.{orig_k}.{k}.weight", None
)
# Norms
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
for orig_k, diffuser_k in path_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
f"double_layers.{i}.{orig_k}.1.weight", None
)
# Attentions
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
for k, v in attn_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
f"double_layers.{i}.attn.{k}.weight", None
)
# Single-DiT blocks
for i in range(single_dit_layers):
# Feed-forward
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for k, v in mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
f"single_layers.{i}.mlp.{k}.weight", None
)
# Norms
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
f"single_layers.{i}.modCX.1.weight", None
)
# Attentions
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
for k, v in x_attn_mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
f"single_layers.{i}.attn.{k}.weight", None
)
# Final blocks
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
# Handle the final norm layer
norm_weight = checkpoint.pop("modF.1.weight", None)
if norm_weight is not None:
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
else:
converted_state_dict["norm_out.linear.weight"] = None
converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
return converted_state_dict
+2 -25
View File
@@ -21,7 +21,6 @@ import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..models.embeddings import (
ImageProjection,
@@ -44,13 +43,11 @@ from ..utils import (
is_torch_version,
logging,
)
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
@@ -411,27 +408,7 @@ class UNet2DConditionLoadersMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
return _func_optionally_disable_offloading(_pipeline=_pipeline)
def save_attn_procs(
self,
@@ -486,6 +486,9 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = 448
self.tile_sample_stride_width = 448
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
@@ -515,6 +518,8 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
@@ -606,11 +611,106 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
batch_size, num_channels, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, x.shape[2], self.tile_sample_stride_height):
row = []
for j in range(0, x.shape[3], self.tile_sample_stride_width):
tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
if (
tile.shape[2] % self.spatial_compression_ratio != 0
or tile.shape[3] % self.spatial_compression_ratio != 0
):
pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
tile = F.pad(tile, (0, pad_w, 0, pad_h))
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=3))
encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
if not return_dict:
return (encoded,)
return EncoderOutput(latent=encoded)
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
batch_size, num_channels, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=3))
decoded = torch.cat(result_rows, dim=2)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
encoded = self.encode(sample, return_dict=False)[0]
@@ -1010,10 +1010,12 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 512
self.tile_sample_min_width = 512
self.tile_sample_min_num_frames = 16
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 448
self.tile_sample_stride_width = 448
self.tile_sample_stride_num_frames = 8
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
@@ -1023,8 +1025,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_min_num_frames: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
tile_sample_stride_num_frames: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1046,8 +1050,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
@@ -1073,18 +1079,13 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
return self._temporal_tiled_encode(x)
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
if self.use_framewise_encoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
enc = self.encoder(x)
enc = self.encoder(x)
return enc
@@ -1121,19 +1122,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, temb, return_dict=return_dict)
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
dec = self.decoder(z, temb)
dec = self.decoder(z, temb)
if not return_dict:
return (dec,)
@@ -1189,6 +1186,14 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
@@ -1217,17 +1222,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
if self.use_framewise_encoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.encoder(
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
)
time = self.encoder(
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
)
row.append(time)
rows.append(row)
@@ -1283,17 +1280,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.decoder(
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
)
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
row.append(time)
rows.append(row)
@@ -1318,6 +1305,74 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return DecoderOutput(sample=dec)
def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
batch_size, num_channels, num_frames, height, width = x.shape
latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
row = []
for i in range(0, num_frames, self.tile_sample_stride_num_frames):
tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
tile = self.tiled_encode(tile)
else:
tile = self.encoder(tile)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
else:
result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
return enc
def _temporal_tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
row = []
for i in range(0, num_frames, tile_latent_stride_num_frames):
tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
decoded = self.tiled_decode(tile, temb, return_dict=True).sample
else:
decoded = self.decoder(tile, temb)
if i > 0:
decoded = decoded[:, :, :-1, :, :]
row.append(decoded)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
result_row.append(tile)
else:
result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
@@ -1334,5 +1389,5 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
z = posterior.mode()
dec = self.decode(z, temb)
if not return_dict:
return (dec,)
return (dec.sample,)
return dec
+7 -2
View File
@@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
def load_state_dict(
checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
@@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
+7 -7
View File
@@ -559,6 +559,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
<Tip>
@@ -604,6 +607,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
quantization_config = kwargs.pop("quantization_config", None)
disable_mmap = kwargs.pop("disable_mmap", False)
allow_pickle = False
if use_safetensors is None:
@@ -883,7 +887,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
else:
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant)
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
@@ -920,14 +924,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU
force_hook = True
device_map = _determine_device_map(
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
)
if device_map is None and is_sharded:
# we load the parameters on the cpu
device_map = {"": "cpu"}
force_hook = False
try:
accelerate.load_checkpoint_and_dispatch(
model,
@@ -937,7 +939,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
force_hooks=force_hook,
strict=True,
)
except AttributeError as e:
@@ -967,7 +968,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
force_hooks=force_hook,
strict=True,
)
model._undo_temp_convert_self_to_deprecated_attention_blocks()
@@ -983,7 +983,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file, variant=variant)
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
model._convert_deprecated_attention_blocks(state_dict)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
@@ -1214,7 +1214,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Adapted from `transformers` modeling_utils.py
def _get_no_split_modules(self, device_map: str):
"""
Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
get the underlying `_no_split_modules`.
Args:
@@ -20,6 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
@@ -253,7 +254,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
@@ -120,8 +120,10 @@ class CogVideoXBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
attention_kwargs = attention_kwargs or {}
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +135,7 @@ class CogVideoXBlock(nn.Module):
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**attention_kwargs,
)
hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -210,6 +213,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
@register_to_config
def __init__(
@@ -497,6 +501,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
encoder_hidden_states,
emb,
image_rotary_emb,
attention_kwargs,
**ckpt_kwargs,
)
else:
@@ -505,6 +510,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
)
if not self.config.use_rotary_positional_embeddings:
@@ -221,6 +221,8 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
Scaling factor to apply in 3D positional embeddings across time dimension.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
@@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
@register_to_config
def __init__(
@@ -542,6 +542,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"""
_supports_gradient_checkpointing = True
_no_split_modules = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
@register_to_config
def __init__(
@@ -713,15 +719,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.zeros(
batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N, N]
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
for i in range(batch_size):
attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
attention_mask = attention_mask.unsqueeze(1) # [B, 1, N, N], for broadcasting across attention heads
attention_mask[i, : effective_sequence_length[i]] = True
attention_mask = attention_mask.unsqueeze(1) # [B, 1, N], for broadcasting across attention heads
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
+20 -15
View File
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -103,6 +103,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
mid_block_type: Optional[str] = "UNetMidBlock2D",
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
layers_per_block: int = 2,
@@ -194,19 +195,22 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
if mid_block_type is None:
self.mid_block = None
else:
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
@@ -322,7 +326,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb)
if self.mid_block is not None:
sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
@@ -33,6 +33,7 @@ from ...utils import (
deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -41,6 +42,14 @@ from ...video_processor import VideoProcessor
from .pipeline_output import AllegroPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__)
if is_bs4_available():
@@ -194,10 +203,10 @@ class AllegroPipeline(DiffusionPipeline):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -921,6 +930,9 @@ class AllegroPipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents)
@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -66,7 +74,9 @@ class AmusedPipeline(DiffusionPipeline):
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.vae_scale_factor = (
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@@ -297,6 +307,9 @@ class AmusedPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
output = latents
else:
@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -81,7 +89,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.vae_scale_factor = (
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@@ -323,6 +333,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
output = latents
else:
@@ -21,10 +21,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -89,7 +97,9 @@ class AmusedInpaintPipeline(DiffusionPipeline):
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.vae_scale_factor = (
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
@@ -354,6 +364,9 @@ class AmusedInpaintPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
output = latents
else:
@@ -34,6 +34,7 @@ from ...schedulers import (
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -47,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -139,7 +148,7 @@ class AnimateDiffPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
@@ -844,6 +853,9 @@ class AnimateDiffPipeline(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# 9. Post processing
if output_type == "latent":
video = latents
@@ -32,7 +32,7 @@ from ...models import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
@@ -41,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -180,7 +188,7 @@ class AnimateDiffControlNetPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_video_processor = VideoProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1090,6 +1098,9 @@ class AnimateDiffControlNetPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 9. Post processing
if output_type == "latent":
video = latents

Some files were not shown because too many files have changed in this diff Show More