Compare commits

..

70 Commits

Author SHA1 Message Date
Sayak Paul 6bf668c4d2 [chore] remove torch.save from remnant code. (#12717)
remove torch.save from remnant code.
2025-11-27 13:04:09 +05:30
Jerry Wu e6d4612309 Support unittest for Z-image ️ (#12715)
* Add Support for Z-Image.

* Reformatting with make style, black & isort.

* Remove init, Modify import utils, Merge forward in transformers block, Remove once func in pipeline.

* modified main model forward, freqs_cis left

* refactored to add B dim

* fixed stack issue

* fixed modulation bug

* fixed modulation bug

* fix bug

* remove value_from_time_aware_config

* styling

* Fix neg embed and devide / bug; Reuse pad zero tensor; Turn cat -> repeat; Add hint for attn processor.

* Replace padding with pad_sequence; Add gradient checkpointing.

* Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that.

* Fix Docstring and Make Style.

* Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that."

This reverts commit fbf26b7ed1.

* update z-image docstring

* Revert attention dispatcher

* update z-image docstring

* styling

* Recover attention_dispatch.py with its origin impl, later would special commit for fa3 compatibility.

* Fix prev bug, and support for prompt_embeds pass in args after prompt pre-encode as List of torch Tensor.

* Remove einop dependency.

* remove redundant imports & make fix-copies

* fix import

* Support for num_images_per_prompt>1; Remove redundant unquote variables.

* Fix bugs for num_images_per_prompt with actual batch.

* Add unit tests for Z-Image.

* Refine unitest and skip for cases needed separate test env; Fix compatibility with unitest in model, mostly precision formating.

* Add clean env for test_save_load_float16 separ test; Add Note; Styling.

* Update dtype mentioned by yiyi.

---------

Co-authored-by: liudongyang <liudongyang0114@gmail.com>
2025-11-26 07:18:57 -10:00
David El Malih a88a7b4f03 Improve docstrings and type hints in scheduling_dpmsolver_multistep.py (#12710)
* Improve docstrings and type hints in multiple diffusion schedulers

* docs: update Imagen Video paper link to Hugging Face Papers.
2025-11-26 08:38:41 -08:00
Sayak Paul c8656ed73c [docs] put autopipeline after overview and hunyuanimage in images (#12548)
put autopipeline after overview and hunyuanimage in images
2025-11-26 15:34:22 +05:30
Sayak Paul 94c9613f99 [docs] Correct flux2 links (#12716)
* fix links

* up
2025-11-26 10:46:51 +05:30
Sayak Paul b91e8c0d0b [lora]: Fix Flux2 LoRA NaN test (#12714)
* up

* Update tests/lora/test_lora_layers_flux2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-11-26 09:07:48 +05:30
Andrei Filatov ac7864624b Update script names in README for Flux2 training (#12713) 2025-11-26 07:02:18 +05:30
Sayak Paul 5ffb73d4ae let's go Flux2 🚀 (#12711)
* add vae

* Initial commit for Flux 2 Transformer implementation

* add pipeline part

* small edits to the pipeline and conversion

* update conversion script

* fix

* up up

* finish pipeline

* Remove Flux IP Adapter logic for now

* Remove deprecated 3D id logic

* Remove ControlNet logic for now

* Add link to ViT-22B paper as reference for parallel transformer blocks such as the Flux 2 single stream block

* update pipeline

* Don't use biases for input projs and output AdaNorm

* up

* Remove bias for double stream block text QKV projections

* Add script to convert Flux 2 transformer to diffusers

* make style and make quality

* fix a few things.

* allow sft files to go.

* fix image processor

* fix batch

* style a bit

* Fix some bugs in Flux 2 transformer implementation

* Fix dummy input preparation and fix some test bugs

* fix dtype casting in timestep guidance module.

* resolve conflicts.,

* remove ip adapter stuff.

* Fix Flux 2 transformer consistency test

* Fix bug in Flux2TransformerBlock (double stream block)

* Get remaining Flux 2 transformer tests passing

* make style; make quality; make fix-copies

* remove stuff.

* fix type annotaton.

* remove unneeded stuff from tests

* tests

* up

* up

* add sf support

* Remove unused IP Adapter and ControlNet logic from transformer (#9)

* copied from

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>

* up

* up

* up

* up

* up

* Refactor Flux2Attention into separate classes for double stream and single stream attention

* Add _supports_qkv_fusion to AttentionModuleMixin to allow subclasses to disable QKV fusion

* Have Flux2ParallelSelfAttention inherit from AttentionModuleMixin with _supports_qkv_fusion=False

* Log debug message when calling fuse_projections on a AttentionModuleMixin subclass that does not support QKV fusion

* Address review comments

* Update src/diffusers/pipelines/flux2/pipeline_flux2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* up

* Remove maybe_allow_in_graph decorators for Flux 2 transformer blocks (#12)

* up

* support ostris loras. (#13)

* up

* update schdule

* up

* up (#17)

* add training scripts (#16)

* add training scripts

Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com>

* model cpu offload in validation.

* add flux.2 readme

* add img2img and tests

* cpu offload in log validation

* Apply suggestions from code review

* fix

* up

* fixes

* remove i2i training tests for now.

---------

Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com>
Co-authored-by: linoytsaban <linoy@huggingface.co>

* up

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Daniel Gu <dgu8957@gmail.com>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-10-53-87-203.ec2.internal>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com>
Co-authored-by: linoytsaban <linoy@huggingface.co>
2025-11-25 21:49:04 +05:30
Jerry Wu 4088e8a851 Add Support for Z-Image Series (#12703)
* Add Support for Z-Image.

* Reformatting with make style, black & isort.

* Remove init, Modify import utils, Merge forward in transformers block, Remove once func in pipeline.

* modified main model forward, freqs_cis left

* refactored to add B dim

* fixed stack issue

* fixed modulation bug

* fixed modulation bug

* fix bug

* remove value_from_time_aware_config

* styling

* Fix neg embed and devide / bug; Reuse pad zero tensor; Turn cat -> repeat; Add hint for attn processor.

* Replace padding with pad_sequence; Add gradient checkpointing.

* Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that.

* Fix Docstring and Make Style.

* Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that."

This reverts commit fbf26b7ed1.

* update z-image docstring

* Revert attention dispatcher

* update z-image docstring

* styling

* Recover attention_dispatch.py with its origin impl, later would special commit for fa3 compatibility.

* Fix prev bug, and support for prompt_embeds pass in args after prompt pre-encode as List of torch Tensor.

* Remove einop dependency.

* remove redundant imports & make fix-copies

* fix import

---------

Co-authored-by: liudongyang <liudongyang0114@gmail.com>
2025-11-25 05:50:00 -10:00
Junsong Chen d33d9f6715 fix typo in docs (#12675)
* fix typo in docs

* Update docs/source/en/api/pipelines/sana_video.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-11-24 19:42:16 -08:00
sq dde8754ba2 Fix variable naming typos in community FluxControlNetFillInpaintPipeline (#12701)
- Fixed variable naming typos (maskkk -> mask_fill, mask_imagee -> mask_image_fill, masked_imagee -> masked_image_fill, masked_image_latentsss -> masked_latents_fill)

These changes improve code readability without affecting functionality.
2025-11-24 15:16:11 -08:00
cdutr fbcd3ba6b2 [i8n-pt] Fix grammar and expand Portuguese documentation (#12598)
* Updates Portuguese documentation for Diffusers library

Enhances the Portuguese documentation with:
- Restructured table of contents for improved navigation
- Added placeholder page for in-translation content
- Refined language and improved readability in existing pages
- Introduced a new page on basic Stable Diffusion performance guidance

Improves overall documentation structure and user experience for Portuguese-speaking users

* Removes untranslated sections from Portuguese documentation

Cleans up the Portuguese documentation table of contents by removing placeholder sections marked as "Em tradução" (In translation)

Removes the in_translation.md file and associated table of contents entries for sections that are not yet translated, improving documentation clarity
2025-11-24 14:07:32 -08:00
Sayak Paul d176f61fcf [core] support sage attention + FA2 through kernels (#12439)
* up

* support automatic dispatch.

* disable compile support for now./

* up

* flash too.

* document.

* up

* up

* up

* up
2025-11-24 16:58:07 +05:30
DefTruth 354d35adb0 bugfix: fix chrono-edit context parallel (#12660)
* bugfix: fix chrono-edit context parallel

* bugfix: fix chrono-edit context parallel

* Update src/diffusers/models/transformers/transformer_chronoedit.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/models/transformers/transformer_chronoedit.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Clean up comments in transformer_chronoedit.py

Removed unnecessary comments regarding parallelization in cross-attention.

* fix style

* fix qc

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-11-24 16:36:53 +05:30
SwayStar123 544ba677dd Add FluxLoraLoaderMixin to Fibo pipeline (#12688)
Update pipeline_bria_fibo.py
2025-11-24 13:31:31 +05:30
David El Malih 6f1042e36c Improve docstrings and type hints in scheduling_lms_discrete.py (#12678)
* Enhance type hints and docstrings in LMSDiscreteScheduler class

Updated type hints for function parameters and return types to improve code clarity and maintainability. Enhanced docstrings for several methods, providing clearer descriptions of their functionality and expected arguments. Notable changes include specifying Literal types for certain parameters and ensuring consistent return type annotations across the class.

* docs: Add specific paper reference to `_convert_to_karras` docstring.

* Refactor `_convert_to_karras` docstring in DPMSolverSDEScheduler to include detailed descriptions and a specific paper reference, enhancing clarity and documentation consistency.
2025-11-21 10:18:09 -08:00
Pratim Dasude d5da453de5 Community Pipeline: FluxFillControlNetInpaintPipeline for FLUX Fill-Based Inpainting with ControlNet (#12649)
* new flux fill controlnet inpaint pipline

* Delete src/diffusers/pipelines/flux/pipline_flux_fill_controlnet_Inpaint.py

deleting from main flux pipeline

* Fluc_fill_controlnet community pipline

* Update README.md

* Apply style fixes
2025-11-19 16:18:46 -03:00
David El Malih 15370f8412 Improve docstrings and type hints in scheduling_pndm.py (#12676)
* Enhance docstrings and type hints in PNDMScheduler class

- Updated parameter descriptions to include default values and specific types using Literal for better clarity.
- Improved docstring formatting and consistency across methods, including detailed explanations for the `_get_prev_sample` method.
- Added type hints for method return types to enhance code readability and maintainability.

* Refactor docstring in PNDMScheduler class to enhance clarity

- Simplified the explanation of the method for computing the previous sample from the current sample.
- Updated the reference to the PNDM paper for better accessibility.
- Removed redundant notation explanations to streamline the documentation.
2025-11-19 09:36:41 -08:00
Dhruv Nair a96b145304 [CI] Fix failing Pipeline CPU tests (#12681)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-19 21:19:24 +05:30
Dhruv Nair 6d8973ffe2 [CI] Fix indentation issue in workflow files (#12685)
update
2025-11-19 09:30:04 +05:30
Sayak Paul ab71f3c864 [core] Refactor hub attn kernels (#12475)
* refactor how attention kernels from hub are used.

* up

* refactor according to Dhruv's ideas.

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: dn6 <dhruv@huggingface.co>

* up

---------

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-11-19 08:19:00 +05:30
Dhruv Nair b7df4a5387 [CI] Temporarily pin transformers (#12677)
* update

* update

* update

* update
2025-11-18 14:43:06 +05:30
dg845 67dc65e2e3 Revert AutoencoderKLWan's dim_mult default value back to list (#12640)
Revert dim_mult back to list and fix type annotation
2025-11-17 18:39:53 +05:30
Dhruv Nair 3579fdabf9 [CI] Make CI logs less verbose (#12674)
update
2025-11-17 14:23:09 +05:30
Junsong Chen 1afc21855e SANA-Video Image to Video pipeline SanaImageToVideoPipeline support (#12634)
* move sana-video to a new dir and add `SanaImageToVideoPipeline` with no modify;

* fix bug and run text/image-to-vidoe success;

* make style; quality; fix-copies;

* add sana image-to-video pipeline in markdown;

* add test case for sana image-to-video;

* make style;

* add a init file in sana-video test dir;

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana_video/test_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana_video/test_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* minor update;

* fix bug and skip fp16 save test;

Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* add copied from for `encode_prompt`

* Apply style fixes

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-17 00:23:34 -08:00
David Bertoin 0c35b580fe [PRX pipeline]: add 1024 resolution ratio bins (#12670)
add 1024 ratio bins
2025-11-17 10:37:40 +05:30
David Bertoin 01a56927f1 Rope in float32 for mps or npu compatibility (#12665)
rope in float32
2025-11-15 20:44:34 +05:30
dg845 a9e4883b6a Update Wan Animate Docs (#12658)
* Update the Wan Animate docs to reflect the most recent code

* Further explain input preprocessing and link to original Wan Animate preprocessing scripts
2025-11-14 16:06:22 -08:00
David El Malih 63dd601758 Improve docstrings and type hints in scheduling_euler_discrete.py (#12654)
* refactor: enhance type hints and documentation in EulerDiscreteScheduler

Updated type hints for function parameters and return types in the EulerDiscreteScheduler class to improve code clarity and maintainability. Enhanced docstrings for several methods to provide clearer descriptions of their functionality and expected arguments. This includes specifying Literal types for certain parameters and ensuring consistent return type annotations across the class.

* refactor: enhance type hints and documentation across multiple schedulers

Updated type hints and improved docstrings in various scheduler classes, including CMStochasticIterativeScheduler, CosineDPMSolverMultistepScheduler, and others. This includes specifying parameter types, return types, and providing clearer descriptions of method functionalities. Notable changes include the addition of default values in the begin_index argument and enhanced explanations for noise addition methods. These improvements aim to enhance code clarity and maintainability across the scheduling module.

* refactor: update docstrings to clarify noise schedule construction

Revised docstrings across multiple scheduler classes to enhance clarity regarding the construction of noise schedules. Updated references to relevant papers, ensuring accurate citations for the methodologies used. This includes changes in DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, and others, improving documentation consistency and readability.
2025-11-14 15:12:24 -08:00
Dhruv Nair eeae0338e7 [Modular] Add Custom Blocks guide to doc (#12339)
* update

* update

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/_toctree.yml

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* update

* update

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* update

* update

* update

* update

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-11-14 10:59:59 +05:30
David El Malih 3c1ca869d7 Improve docstrings and type hints in scheduling_ddpm.py (#12651)
* Enhance type hints and docstrings in scheduling_ddpm.py

- Added type hints for function parameters and return types across the DDPMScheduler class and related functions.
- Improved docstrings for clarity, including detailed descriptions of parameters and return values.
- Updated the alpha_transform_type and beta_schedule parameters to use Literal types for better type safety.
- Refined the _get_variance and previous_timestep methods with comprehensive documentation.

* Refactor docstrings and type hints in scheduling_ddpm.py

- Cleaned up whitespace in the rescale_zero_terminal_snr function.
- Enhanced the variance_type parameter in the DDPMScheduler class with improved formatting for better readability.
- Updated the docstring for the compute_variance method to maintain consistency and clarity in parameter descriptions and return values.

* Apply `make fix-copies`

* Refactor type hints across multiple scheduler files

- Updated type hints to include `Literal` for improved type safety in various scheduling files.
- Ensured consistency in type hinting for parameters and return types across the affected modules.
- This change enhances code clarity and maintainability.
2025-11-13 14:46:23 -08:00
David El Malih 6fe4a6ff8e Improve docstrings and type hints in scheduling_ddim.py (#12622)
* Improve docstrings and type hints in scheduling_ddim.py

- Add complete type hints for all function parameters
- Enhance docstrings to follow project conventions
- Add missing parameter descriptions

Fixes #9567

* Enhance docstrings and type hints in scheduling_ddim.py

- Update parameter types and descriptions for clarity
- Improve explanations in method docstrings to align with project standards
- Add optional annotations for parameters where applicable

* Refine type hints and docstrings in scheduling_ddim.py

- Update parameter types to use Literal for specific string options
- Enhance docstring descriptions for clarity and consistency
- Ensure all parameters have appropriate type annotations and defaults

* Apply review feedback on scheduling_ddim.py

- Replace "prevent singularities" with "avoid numerical instability" for better clarity
- Add backticks around `alpha_bar` variable name for consistent formatting
- Convert Imagen Video paper URLs to Hugging Face papers references

* Propagate changes using 'make fix-copies'

* Add missing Literal
2025-11-13 14:45:58 -08:00
Steven Liu 40de88af8c [docs] AutoModel (#12644)
* automodel

* fix
2025-11-13 08:43:24 -08:00
Steven Liu 6a2309b98d [utils] Update check_doc_toc (#12642)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-13 08:42:31 -08:00
Sayak Paul cd3bbe2910 skip autoencoderdl layerwise casting memory (#12647) 2025-11-13 12:56:22 +05:30
kaixuanliu 7a001c3ee2 adjust unit tests for test_save_load_float16 (#12500)
* adjust unit tests for wan pipeline

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* avoid adjusting common `get_dummy_components` API

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* use `form_pretrained` to `transformer` and `transformer_2`

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-11-13 11:57:12 +05:30
dg845 d8e4805816 [WIP]Add Wan2.2 Animate Pipeline (Continuation of #12442 by tolgacangoz) (#12526)
---------

Co-authored-by: Tolga Cangöz <mtcangoz@gmail.com>
Co-authored-by: Tolga Cangöz <46008593+tolgacangoz@users.noreply.github.com>
2025-11-12 16:52:31 -10:00
David El Malih 44c3101685 Improve docstrings and type hints in scheduling_amused.py (#12623)
* Improve docstrings and type hints in scheduling_amused.py

- Add complete type hints for helper functions (gumbel_noise, mask_by_random_topk)
- Enhance AmusedSchedulerOutput with proper Optional typing
- Add comprehensive docstrings for AmusedScheduler class
- Improve __init__, set_timesteps, step, and add_noise methods
- Fix type hints to match documentation conventions
- All changes follow project standards from issue #9567

* Enhance type hints and docstrings in scheduling_amused.py

- Update type hints for `prev_sample` and `pred_original_sample` in `AmusedSchedulerOutput` to reflect their tensor types.
- Improve docstring for `gumbel_noise` to specify the output tensor's dtype and device.
- Refine `AmusedScheduler` class documentation, including detailed descriptions of the masking schedule and temperature parameters.
- Adjust type hints in `set_timesteps` and `step` methods for better clarity and consistency.

* Apply review feedback on scheduling_amused.py

- Replace generic [Amused] reference with specific [`AmusedPipeline`] reference for consistency with project documentation conventions
2025-11-12 17:26:10 -08:00
YiYi Xu d6c63bb956 [modular] add a check (#12628)
* add

* fix
2025-11-12 07:59:18 -10:00
Steven Liu 2f44d63046 [docs] Update install instructions (#12626)
remove commit

Removed specific commit reference for installation instructions.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-11-12 09:21:24 -08:00
Quentin Gallouédec f3db38c1e7 ArXiv -> HF Papers (#12583)
* Update pipeline_skyreels_v2_i2v.py

* Update README.md

* Update torch_utils.py

* Update torch_utils.py

* Update guider_utils.py

* Update pipeline_ltx.py

* Update pipeline_bria.py

* Apply suggestion from @qgallouedec

* Update autoencoder_kl_qwenimage.py

* Update pipeline_prx.py

* Update pipeline_wan_vace.py

* Update pipeline_skyreels_v2.py

* Update pipeline_skyreels_v2_diffusion_forcing.py

* Update pipeline_bria_fibo.py

* Update pipeline_skyreels_v2_diffusion_forcing_i2v.py

* Update pipeline_ltx_condition.py

* Update pipeline_ltx_image2video.py

* Update regional_prompting_stable_diffusion.py

* make style

* style

* style
2025-11-12 08:37:21 -08:00
Sayak Paul f5e5f34823 [modular] add tests for qwen modular (#12585)
* add tests for qwenimage modular.

* qwenimage edit.

* qwenimage edit plus.

* empty

* align with the latest structure

* up

* up

* reason

* up

* fix multiple issues.

* up

* up

* fix

* up

* make it similar to the original pipeline.
2025-11-12 17:37:42 +05:30
YiYi Xu 093cd3f040 fix dispatch_attention_fn check (#12636)
* fix

* fix
2025-11-11 19:16:13 -10:00
a120092009 aecf0c53bf Add MLU Support. (#12629)
* Add MLU Support.

* fix comment.

* rename is_mlu_available to is_torch_mlu_available

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-11 19:15:26 -10:00
YiYi Xu 0c7589293b fix copies (#12637)
* fix

* remoce cocpies instead
2025-11-11 15:44:55 -10:00
Charchit Sharma ff263947ad Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers (#12594)
* Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers

- Store t_dim, h_dim, w_dim as instance variables in WanRotaryPosEmbed and SkyReelsV2RotaryPosEmbed __init__
- Use stored dimensions in forward() instead of recalculating with different formula
- Fixes inconsistency between init (using // 6) and forward (using // 3)
- Ensures split_sizes matches the dimensions used to create rotary embeddings

* quality fix

---------

Co-authored-by: Charchit Sharma <charchitsharma@A-267.local>
2025-11-11 11:45:36 -10:00
Dhruv Nair 66e6a0215f [CI] Remove unittest dependency from testing_utils.py (#12621)
* update

* Update tests/testing_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/testing_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-11 16:40:39 +05:30
Cesaryuan 5a47442f92 Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples (#12544)
* Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-10 13:57:52 -08:00
Dhruv Nair 8f6328c4a4 [Modular] Clean up docs (#12604)
update

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-11-10 23:37:29 +05:30
Dhruv Nair 8d45f219d0 Fix Context Parallel validation checks (#12446)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-10 23:37:07 +05:30
Yashwant Bezawada 0fd58c7706 fix: correct import path for load_model_dict_into_meta in conversion scripts (#12616)
The function load_model_dict_into_meta was moved from modeling_utils.py to
model_loading_utils.py but the imports in the conversion scripts were not
updated, causing ImportError when running these scripts.

This fixes the import in 6 conversion scripts:
- scripts/convert_sd3_to_diffusers.py
- scripts/convert_stable_cascade_lite.py
- scripts/convert_stable_cascade.py
- scripts/convert_stable_audio.py
- scripts/convert_sana_to_diffusers.py
- scripts/convert_sana_controlnet_to_diffusers.py

Fixes #12606
2025-11-10 14:47:18 +05:30
Dhruv Nair 35d703310c [CI] Fix typo in uv install (#12618)
update
2025-11-10 13:22:46 +05:30
YiYi Xu b455dc94a2 [modular] wan! (#12611)
* update, remove intermediaate_inputs

* support image2video

* revert dynamic steps to simplify

* refactor vae encoder block

* support flf2video!

* add support for wan2.2 14B

* style

* Apply suggestions from code review

* input dynamic step -> additiional input step

* up

* fix init

* update dtype
2025-11-09 21:48:50 -10:00
Jay Wu 04f9d2bf3d add ChronoEdit (#12593)
* add ChronoEdit

* add ref to  original function & remove wan2.2 logics

* Update src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* add ChronoeEdit test

* add docs

* add docs

* make fix-copies

* fix chronoedit test

---------

Co-authored-by: wjay <wjay@nvidia.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-09 22:07:00 -08:00
Dhruv Nair bc8fd864eb [CI] Push test fix (#12617)
update
2025-11-10 09:26:14 +05:30
Wang, Yi a9cb08af39 fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled (#12562)
* fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* address review comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-11-07 20:00:13 +05:30
DefTruth 9f669e7b5d feat: enable attention dispatch for huanyuan video (#12591)
* feat: enable attention dispatch for huanyuan video

* feat: enable attention dispatch for huanyuan video
2025-11-07 11:22:41 +05:30
Dhruv Nair 8ac17cd2cb [Modular] Some clean up for Modular tests (#12579)
* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-07 08:19:15 +05:30
Mohammad Sadegh Salehi e4393fa613 Fix overflow and dtype handling in rgblike_to_depthmap (NumPy + PyTorch) (#12546)
* Fix overflow in rgblike_to_depthmap by safe dtype casting (torch & NumPy)

* Fix: store original dtype and cast back after safe computation

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-06 08:18:21 -10:00
Junsong Chen b3e9dfced7 [SANA-Video] Adding 5s pre-trained 480p SANA-Video inference (#12584)
* 1. add `SanaVideoTransformer3DModel` in transformer_sana_video.py
2. add `SanaVideoPipeline` in pipeline_sana_video.py
3. add all code we need for import `SanaVideoPipeline`

* add a sample about how to use sana-video;

* code update;

* update hf model path;

* update code;

* sana-video can run now;

* 1. add aspect ratio in sana-video-pipeline;
2. add reshape function in sana-video-processor;
3. fix convert pth to safetensor bugs;

* default to use `use_resolution_binning`;

* make style;

* remove unused code;

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana/pipeline_sana_video.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

* Update src/diffusers/pipelines/sana/pipeline_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana/pipeline_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* support `dispatch_attention_fn`

* 1. add sana-video markdown;
2. fix typos;

* add two test case for sana-video (need check)

* fix text-encoder in test-sana-video;

* Update tests/pipelines/sana/test_sana_video.py

* Update tests/pipelines/sana/test_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana/test_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana/test_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana/test_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana/test_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana/pipeline_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/video_processor.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* make style
make quality
make fix-copies

* toctree yaml update;

* add sana-video-transformer3d markdown;

* Apply style fixes

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-05 21:08:47 -08:00
Joseph Turian 58f3771545 Add optional precision-preserving preprocessing for examples/unconditional_image_generation/train_unconditional.py (#12596)
* Add optional precision-preserving preprocessing

* Document decoder caveat for precision flag

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-06 09:37:31 +05:30
Dhruv Nair 6198f8a12b [Modular] Allow ModularPipeline to load from revisions (#12592)
* update

* update

* update

* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-11-06 07:54:24 +05:30
Linoy Tsaban dcfb18a2d3 [LoRA] add support for more Qwen LoRAs (#12581)
* fix bug when offload and cache_latents both enabled

* fix
2025-11-04 14:27:25 +02:00
Sayak Paul ac5a1e28fc [docs] sort doc (#12586)
sort doc
2025-11-04 10:26:07 +05:30
Lev Novitskiy 325a95051b Kandinsky 5.0 Docs fixes (#12582)
* add transformer pipeline first version

* updates

* fix 5sec generation

* rewrite Kandinsky5T2VPipeline to diffusers style

* add multiprompt support

* remove prints in pipeline

* add nabla attention

* Wrap Transformer in Diffusers style

* fix license

* fix prompt type

* add gradient checkpointing and peft support

* add usage example

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* remove unused imports

* add 10 second models support

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove no_grad and simplified prompt paddings

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* moved template to __init__

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* moved sdps inside processor

* remove oneline function

* remove reset_dtype methods

* Transformer: move all methods to forward

* separated prompt encoding

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* refactoring

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* refactoring acording to https://github.com/huggingface/diffusers/commit/acabbc0033d4b4933fc651766a4aa026db2e6dc1

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* fixed

* style +copies

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: Charles <charles@huggingface.co>

* more

* Apply suggestions from code review

* add lora loader doc

* add compiled Nabla Attention

* all needed changes for 10 sec models are added!

* add docs

* Apply style fixes

* update docs

* add kandinsky5 to toctree

* add tests

* fix tests

* Apply style fixes

* update tests

* minor docs refactoring

* refactor Kandinsky 5.0 Vide docs

* Update docs/source/en/_toctree.yml

---------

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Charles <charles@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-03 14:38:07 -10:00
Wang, Yi 1ec28a2c77 ulysses enabling in native attention path (#12563)
* ulysses enabling in native attention path

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* address review comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add supports_context_parallel for native attention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update templated attention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-03 11:48:20 -10:00
YiYi Xu de6173c683 [modular]pass hub_kwargs to load_config (#12577)
pass hub_kwargs to load_config
2025-11-03 09:44:42 -10:00
Sayak Paul 8f80dda193 [tests] add tests for flux modular (t2i, i2i, kontext) (#12566)
* start flux modular tests.

* up

* add kontext

* up

* up

* up

* Update src/diffusers/modular_pipelines/flux/denoise.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* up

* up

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-11-02 10:51:11 +05:30
YiYi Xu cdbf0ad883 [modular] better warn message (#12573)
better warn message
2025-11-01 18:45:09 -10:00
Dhruv Nair 5e8415a311 Fix custom code loading in Automodel (#12571)
update
2025-11-01 17:04:31 -10:00
674 changed files with 37316 additions and 10762 deletions
+21 -7
View File
@@ -73,6 +73,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -84,7 +86,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
--report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
tests/pipelines/${{ matrix.module }}
@@ -126,6 +128,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
@@ -138,7 +142,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
--report-log=tests_torch_${{ matrix.module }}_cuda.log \
tests/${{ matrix.module }}
@@ -151,7 +155,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v --make-reports=examples_torch_cuda \
--make-reports=examples_torch_cuda \
--report-log=examples_torch_cuda.log \
examples/
@@ -190,6 +194,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -198,7 +204,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -232,6 +238,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -281,6 +289,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -293,7 +303,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_version_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -358,6 +368,8 @@ jobs:
uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
fi
uv pip install pytest-reportlog
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -405,6 +417,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip install -U bitsandbytes optimum_quanto
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -531,7 +545,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
@@ -587,7 +601,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
+1 -1
View File
@@ -22,7 +22,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install -e .
+3 -2
View File
@@ -109,7 +109,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -120,7 +121,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
+10 -8
View File
@@ -35,7 +35,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install --upgrade pip
@@ -55,7 +55,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install --upgrade pip
@@ -115,7 +115,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -126,7 +127,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/pipelines
@@ -134,7 +135,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_models' }}
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and not Dependency" \
-k "not Flax and not Onnx and not Dependency" \
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
@@ -246,7 +247,8 @@ jobs:
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
uv pip install -U tokenizers
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -255,11 +257,11 @@ jobs:
- name: Run fast PyTorch LoRA tests with PEFT
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
\
--make-reports=tests_peft_main \
tests/lora/
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
\
--make-reports=tests_models_lora_peft_main \
tests/models/ -k "lora"
+20 -17
View File
@@ -1,4 +1,4 @@
name: Fast GPU Tests on PR
name: Fast GPU Tests on PR
on:
pull_request:
@@ -36,7 +36,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install --upgrade pip
@@ -56,7 +56,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install --upgrade pip
@@ -71,7 +71,7 @@ jobs:
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
setup_torch_cuda_pipeline_matrix:
needs: [check_code_quality, check_repository_consistency]
name: Setup Torch Pipelines CUDA Slow Tests Matrix
@@ -131,7 +131,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -149,18 +150,18 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
else
else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
-k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
fi
fi
- name: Failure short reports
if: ${{ failure() }}
@@ -201,7 +202,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -222,11 +224,11 @@ jobs:
run: |
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
if [ -z "$pattern" ]; then
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
else
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
fi
- name: Failure short reports
@@ -262,7 +264,8 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install -e ".[quality,training]"
- name: Environment
@@ -274,7 +277,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
@@ -22,7 +22,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install -e .
+11 -5
View File
@@ -76,6 +76,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -86,7 +88,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -127,6 +129,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -139,7 +143,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_cuda_${{ matrix.module }} \
tests/${{ matrix.module }}
@@ -178,6 +182,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -186,7 +192,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -227,7 +233,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -270,7 +276,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+1 -1
View File
@@ -70,7 +70,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch' }}
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
+1 -1
View File
@@ -57,7 +57,7 @@ jobs:
HF_HOME: /System/Volumes/Data/mnt/cache
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
${CONDA_RUN} python -m pytest -n 0 --make-reports=tests_torch_mps tests/
- name: Failure short reports
if: ${{ failure() }}
+1 -1
View File
@@ -47,7 +47,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
+6 -6
View File
@@ -84,7 +84,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -137,7 +137,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
tests/${{ matrix.module }}
@@ -187,7 +187,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -240,7 +240,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -281,7 +281,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -326,7 +326,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+24 -6
View File
@@ -22,6 +22,8 @@
title: Reproducibility
- local: using-diffusers/schedulers
title: Schedulers
- local: using-diffusers/automodel
title: AutoModel
- local: using-diffusers/other-formats
title: Model formats
- local: using-diffusers/push_to_hub
@@ -119,6 +121,8 @@
title: ComponentsManager
- local: modular_diffusers/guiders
title: Guiders
- local: modular_diffusers/custom_blocks
title: Building Custom Blocks
title: Modular Diffusers
- isExpanded: false
sections:
@@ -329,6 +333,8 @@
title: BriaTransformer2DModel
- local: api/models/chroma_transformer
title: ChromaTransformer2DModel
- local: api/models/chronoedit_transformer_3d
title: ChronoEditTransformer3DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- local: api/models/cogview3plus_transformer2d
@@ -343,6 +349,8 @@
title: DiTTransformer2DModel
- local: api/models/easyanimate_transformer3d
title: EasyAnimateTransformer3DModel
- local: api/models/flux2_transformer
title: Flux2Transformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/hidream_image_transformer
@@ -373,6 +381,8 @@
title: QwenImageTransformer2DModel
- local: api/models/sana_transformer2d
title: SanaTransformer2DModel
- local: api/models/sana_video_transformer3d
title: SanaVideoTransformer3DModel
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/skyreels_v2_transformer_3d
@@ -383,6 +393,8 @@
title: Transformer2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
- local: api/models/wan_animate_transformer_3d
title: WanAnimateTransformer3DModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
title: Transformers
@@ -444,6 +456,8 @@
- sections:
- local: api/pipelines/overview
title: Overview
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- sections:
- local: api/pipelines/audioldm
title: AudioLDM
@@ -456,8 +470,6 @@
- local: api/pipelines/stable_audio
title: Stable Audio
title: Audio
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- sections:
- local: api/pipelines/amused
title: aMUSEd
@@ -515,12 +527,16 @@
title: EasyAnimate
- local: api/pipelines/flux
title: Flux
- local: api/pipelines/flux2
title: Flux2
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuanimage21
title: HunyuanImage2.1
- local: api/pipelines/pix2pix
title: InstructPix2Pix
- local: api/pipelines/kandinsky
@@ -529,8 +545,6 @@
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/kandinsky5
title: Kandinsky 5
- local: api/pipelines/kolors
title: Kolors
- local: api/pipelines/latent_consistency_models
@@ -565,6 +579,8 @@
title: Sana
- local: api/pipelines/sana_sprint
title: Sana Sprint
- local: api/pipelines/sana_video
title: Sana Video
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
@@ -626,18 +642,20 @@
- sections:
- local: api/pipelines/allegro
title: Allegro
- local: api/pipelines/chronoedit
title: ChronoEdit
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/consisid
title: ConsisID
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/hunyuanimage21
title: HunyuanImage2.1
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
title: I2VGen-XL
- local: api/pipelines/kandinsky5_video
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ltx_video
+6 -1
View File
@@ -30,7 +30,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
> [!TIP]
@@ -56,6 +57,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
## Flux2LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
## CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
+1 -9
View File
@@ -12,15 +12,7 @@ specific language governing permissions and limitations under the License.
# AutoModel
The `AutoModel` is designed to make it easy to load a checkpoint without needing to know the specific model class. `AutoModel` automatically retrieves the correct model class from the checkpoint `config.json` file.
```python
from diffusers import AutoModel, AutoPipelineForText2Image
unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
pipe = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet)
```
[`AutoModel`] automatically retrieves the correct model class from the checkpoint `config.json` file.
## AutoModel
@@ -0,0 +1,32 @@
<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# ChronoEditTransformer3DModel
A Diffusion Transformer model for 3D video-like data from [ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
The model can be loaded with the following code snippet.
```python
from diffusers import ChronoEditTransformer3DModel
transformer = ChronoEditTransformer3DModel.from_pretrained("nvidia/ChronoEdit-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## ChronoEditTransformer3DModel
[[autodoc]] ChronoEditTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Flux2Transformer2DModel
A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-labs/FLUX.2-dev).
## Flux2Transformer2DModel
[[autodoc]] Flux2Transformer2DModel
@@ -0,0 +1,36 @@
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# SanaVideoTransformer3DModel
A Diffusion Transformer model for 3D data (video) from [SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
The abstract from the paper is:
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.*
The model can be loaded with the following code snippet.
```python
from diffusers import SanaVideoTransformer3DModel
import torch
transformer = SanaVideoTransformer3DModel.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## SanaVideoTransformer3DModel
[[autodoc]] SanaVideoTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,30 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# WanAnimateTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team.
The model can be loaded with the following code snippet.
```python
from diffusers import WanAnimateTransformer3DModel
transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## WanAnimateTransformer3DModel
[[autodoc]] WanAnimateTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
+156
View File
@@ -0,0 +1,156 @@
<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</a>
</div>
</div>
# ChronoEdit
[ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
*Recent advances in large generative models have greatly enhanced both image editing and in-context image generation, yet a critical gap remains in ensuring physical consistency, where edited objects must remain coherent. This capability is especially vital for world simulation related tasks. In this paper, we present ChronoEdit, a framework that reframes image editing as a video generation problem. First, ChronoEdit treats the input and edited images as the first and last frames of a video, allowing it to leverage large pretrained video generative models that capture not only object appearance but also the implicit physics of motion and interaction through learned temporal consistency. Second, ChronoEdit introduces a temporal reasoning stage that explicitly performs editing at inference time. Under this setting, target frame is jointly denoised with reasoning tokens to imagine a plausible editing trajectory that constrains the solution space to physically viable transformations. The reasoning tokens are then dropped after a few steps to avoid the high computational cost of rendering a full video. To validate ChronoEdit, we introduce PBench-Edit, a new benchmark of image-prompt pairs for contexts that require physical consistency, and demonstrate that ChronoEdit surpasses state-of-the-art baselines in both visual fidelity and physical plausibility. Project page for code and models: [this https URL](https://research.nvidia.com/labs/toronto-ai/chronoedit).*
The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
### Image Editing
```py
import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
model_id = "nvidia/ChronoEdit-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = load_image(
"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
)
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print("width", width, "height", height)
image = image.resize((width, height))
prompt = (
"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cups liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
"The mouses pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacups floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
)
output = pipe(
image=image,
prompt=prompt,
height=height,
width=width,
num_frames=5,
num_inference_steps=50,
guidance_scale=5.0,
enable_temporal_reasoning=False,
num_temporal_reasoning_steps=0,
).frames[0]
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
```
Optionally, enable **temporal reasoning** for improved physical consistency:
```py
output = pipe(
image=image,
prompt=prompt,
height=height,
width=width,
num_frames=29,
num_inference_steps=50,
guidance_scale=5.0,
enable_temporal_reasoning=True,
num_temporal_reasoning_steps=50,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
```
### Inference with 8-Step Distillation Lora
```py
import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
model_id = "nvidia/ChronoEdit-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=1.0)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
pipe.to("cuda")
image = load_image(
"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
)
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print("width", width, "height", height)
image = image.resize((width, height))
prompt = (
"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cups liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
"The mouses pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacups floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
)
output = pipe(
image=image,
prompt=prompt,
height=height,
width=width,
num_frames=5,
num_inference_steps=8,
guidance_scale=1.0,
enable_temporal_reasoning=False,
num_temporal_reasoning_steps=0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
```
## ChronoEditPipeline
[[autodoc]] ChronoEditPipeline
- all
- __call__
## ChronoEditPipelineOutput
[[autodoc]] pipelines.chronoedit.pipeline_output.ChronoEditPipelineOutput
+33
View File
@@ -0,0 +1,33 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Flux2
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Flux.2 is the recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture and pre-training done from scratch!
Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2).
> [!TIP]
> Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
>
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## Flux2Pipeline
[[autodoc]] Flux2Pipeline
- all
- __call__
@@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Kandinsky 5.0
# Kandinsky 5.0 Video
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
@@ -92,7 +92,7 @@ pipe = pipe.to("cuda")
pipe.transformer.set_attention_backend(
"flex"
) # <--- Set attention backend to Flex
) # <--- Sett attention bakend to Flex
pipe.transformer.compile(
mode="max-autotune-no-cudagraphs",
dynamic=True
@@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9)
```
### Diffusion Distilled model
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
```python
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
@@ -24,9 +24,6 @@ The abstract from the paper is:
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
> [!TIP]
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
Available models:
+189
View File
@@ -0,0 +1,189 @@
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Sana-Video
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
[SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
The abstract from the paper is:
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation. [this https URL](https://github.com/NVlabs/SANA).*
This pipeline was contributed by SANA Team. The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://hf.co/collections/Efficient-Large-Model/sana-video).
Available models:
| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`Efficient-Large-Model/SANA-Video_2B_480p_diffusers`](https://huggingface.co/Efficient-Large-Model/ANA-Video_2B_480p_diffusers) | `torch.bfloat16` |
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-video) collection for more information.
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
## Generation Pipelines
<hfoptions id="generation pipelines">`
<hfoption id="Text-to-Video">
The example below demonstrates how to use the text-to-video pipeline to generate a video using a text description.
```python
pipe = SanaVideoPipeline.from_pretrained(
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
torch_dtype=torch.bfloat16,
)
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.float32)
pipe.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
frames=81,
guidance_scale=6,
num_inference_steps=50,
generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]
export_to_video(video, "sana_video.mp4", fps=16)
```
</hfoption>
<hfoption id="Image-to-Video">
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description and a starting frame.
```python
pipe = SanaImageToVideoPipeline.from_pretrained(
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
torch_dtype=torch.bfloat16,
)
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.vae.to(torch.float32)
pipe.text_encoder.to(torch.bfloat16)
pipe.to("cuda")
image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt
motion_scale = 30.0
video = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
frames=81,
guidance_scale=6,
num_inference_steps=50,
generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]
export_to_video(video, "sana-i2v.mp4", fps=16)
```
</hfoption>
</hfoptions>
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaVideoPipeline`] for inference with bitsandbytes.
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModel.from_pretrained(
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
subfolder="text_encoder",
quantization_config=quant_config,
torch_dtype=torch.float16,
)
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaVideoTransformer3DModel.from_pretrained(
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.float16,
)
pipeline = SanaVideoPipeline.from_pretrained(
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
text_encoder=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
)
model_score = 30
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_prompt = f" motion score: {model_score}."
prompt = prompt + motion_prompt
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
num_frames=81,
guidance_scale=6.0,
num_inference_steps=50
).frames[0]
export_to_video(output, "sana-video-output.mp4", fps=16)
```
## SanaVideoPipeline
[[autodoc]] SanaVideoPipeline
- all
- __call__
## SanaImageToVideoPipeline
[[autodoc]] SanaImageToVideoPipeline
- all
- __call__
## SanaVideoPipelineOutput
[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput
+226 -17
View File
@@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers:
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)
> [!TIP]
> Click on the Wan models in the right sidebar for more examples of video generation.
@@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained(
pipeline.to("cuda")
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
@@ -150,15 +151,15 @@ pipeline.transformer = torch.compile(
)
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
@@ -249,6 +250,208 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
</hfoption>
</hfoptions>
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*
The project page: https://humanaigc.github.io/wan-animate
This model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
#### Usage
The Wan-Animate pipeline supports two modes of operation:
1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos
2. **Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene
##### Prerequisites
Before using the pipeline, you need to preprocess your reference video to extract:
- **Pose video**: Contains skeletal keypoints representing body motion
- **Face video**: Contains facial feature representations for expression control
For replacement mode, you additionally need:
- **Background video**: The original video containing the scene
- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black)
> [!NOTE]
> Raw videos should not be used for inputs such as `pose_video`, which the pipeline expects to be preprocessed to extract the proper information. Preprocessing scripts to prepare these inputs are available in the [original Wan-Animate repository](https://github.com/Wan-Video/Wan2.2?tab=readme-ov-file#1-preprocessing). Integration of these preprocessing steps into Diffusers is planned for a future release.
The example below demonstrates how to use the Wan-Animate pipeline:
<hfoptions id="Animate usage">
<hfoption id="Animation mode">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load character image and preprocessed videos
image = load_image("path/to/character.jpg")
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
# Resize image to match VAE constraints
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person dancing energetically in a studio with dynamic lighting and professional camera work"
negative_prompt = "blurry, low quality, distorted, deformed, static, poorly drawn"
# Generate animated video
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
segment_frame_length=77,
guidance_scale=1.0,
mode="animate", # Animation mode (default)
).frames[0]
export_to_video(output, "animated_character.mp4", fps=30)
```
</hfoption>
<hfoption id="Replacement mode">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load all required inputs for replacement mode
image = load_image("path/to/new_character.jpg")
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
background_video = load_video("path/to/background_video.mp4") # Original scene
mask_video = load_video("path/to/mask_video.mp4") # Black: preserve, White: generate
# Resize image to match video dimensions
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person seamlessly integrated into the scene with consistent lighting and environment"
negative_prompt = "blurry, low quality, inconsistent lighting, floating, disconnected from scene"
# Replace character in background video
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
background_video=background_video,
mask_video=mask_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
segment_frame_lengths=77,
guidance_scale=1.0,
mode="replace", # Replacement mode
).frames[0]
export_to_video(output, "character_replaced.mp4", fps=30)
```
</hfoption>
<hfoption id="Advanced options">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = load_image("path/to/character.jpg")
pose_video = load_video("path/to/pose_video.mp4")
face_video = load_video("path/to/face_video.mp4")
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person dancing energetically in a studio"
negative_prompt = "blurry, low quality"
# Advanced: Use temporal guidance and custom callback
def callback_fn(pipe, step_index, timestep, callback_kwargs):
# You can modify latents or other tensors here
print(f"Step {step_index}, Timestep {timestep}")
return callback_kwargs
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
segment_frame_length=77,
num_inference_steps=50,
guidance_scale=5.0,
prev_segment_conditioning_frames=5, # Use 5 frames for temporal guidance (1 or 5 recommended)
callback_on_step_end=callback_fn,
callback_on_step_end_tensor_inputs=["latents"],
).frames[0]
export_to_video(output, "animated_advanced.mp4", fps=30)
```
</hfoption>
</hfoptions>
#### Key Parameters
- **mode**: Choose between `"animate"` (default) or `"replace"`
- **prev_segment_conditioning_frames**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory
- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt. For Wan-Animate, CFG is disabled by default (`guidance_scale=1.0`) but can be enabled to support negative prompts and finer control over facial expressions. (Note that CFG will only target the text prompt and face conditioning.)
## Notes
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
@@ -281,10 +484,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
# use "steamboat willie style" to trigger the LoRA
prompt = """
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
@@ -359,6 +562,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
- all
- __call__
## WanAnimatePipeline
[[autodoc]] WanAnimatePipeline
- all
- __call__
## WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
@@ -0,0 +1,492 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Building Custom Blocks
[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.
> [!TIP]
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana.
## Project Structure
Your custom block project should use the following structure:
```shell
.
├── block.py
└── modular_config.json
```
- `block.py` contains the custom block implementation
- `modular_config.json` contains the metadata needed to load the block
## Example: Florence 2 Inpainting Block
In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.
The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub.
```py
# Inside block.py
from diffusers.modular_pipelines import (
ModularPipelineBlocks,
ComponentSpec,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
```
Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
```py
from typing import List, Union
from PIL import Image, ImageDraw
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
type_hint=Union[Image.Image, List[Image.Image]],
required=True,
description="Image(s) to annotate",
),
InputParam(
"annotation_task",
type_hint=Union[str, List[str]],
required=True,
default="<REFERRING_EXPRESSION_SEGMENTATION>",
description="""Annotation Task to perform on the image.
Supported Tasks:
<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
""",
),
InputParam(
"annotation_prompt",
type_hint=Union[str, List[str]],
required=True,
description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
),
InputParam(
"annotation_output_type",
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Availabe options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
- mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
),
InputParam(
"annotation_overlay",
type_hint=bool,
required=True,
default=False,
description="",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"mask_image",
type_hint=Image,
description="Inpainting Mask for input Image(s)",
),
OutputParam(
"annotations",
type_hint=dict,
description="Annotations Predictions for input Image(s)",
),
OutputParam(
"image",
type_hint=Image,
description="Annotated input Image(s)",
),
]
```
Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
```py
from typing import List, Union
from PIL import Image, ImageDraw
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
type_hint=Union[Image.Image, List[Image.Image]],
required=True,
description="Image(s) to annotate",
),
InputParam(
"annotation_task",
type_hint=Union[str, List[str]],
required=True,
default="<REFERRING_EXPRESSION_SEGMENTATION>",
description="""Annotation Task to perform on the image.
Supported Tasks:
<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
""",
),
InputParam(
"annotation_prompt",
type_hint=Union[str, List[str]],
required=True,
description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
),
InputParam(
"annotation_output_type",
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Availabe options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
- mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
),
InputParam(
"annotation_overlay",
type_hint=bool,
required=True,
default=False,
description="",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"mask_image",
type_hint=Image,
description="Inpainting Mask for input Image(s)",
),
OutputParam(
"annotations",
type_hint=dict,
description="Annotations Predictions for input Image(s)",
),
OutputParam(
"image",
type_hint=Image,
description="Annotated input Image(s)",
),
]
def get_annotations(self, components, images, prompts, task):
task_prompts = [task + prompt for prompt in prompts]
inputs = components.image_annotator_processor(
text=task_prompts, images=images, return_tensors="pt"
).to(components.image_annotator.device, components.image_annotator.dtype)
generated_ids = components.image_annotator.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
annotations = components.image_annotator_processor.batch_decode(
generated_ids, skip_special_tokens=False
)
outputs = []
for image, annotation in zip(images, annotations):
outputs.append(
components.image_annotator_processor.post_process_generation(
annotation, task=task, image_size=(image.width, image.height)
)
)
return outputs
def prepare_mask(self, images, annotations, overlay=False, fill="white"):
masks = []
for image, annotation in zip(images, annotations):
mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
draw = ImageDraw.Draw(mask_image)
for _, _annotation in annotation.items():
if "polygons" in _annotation:
for polygon in _annotation["polygons"]:
polygon = np.array(polygon).reshape(-1, 2)
if len(polygon) < 3:
continue
polygon = polygon.reshape(-1).tolist()
draw.polygon(polygon, fill=fill)
elif "bbox" in _annotation:
bbox = _annotation["bbox"]
draw.rectangle(bbox, fill="white")
masks.append(mask_image)
return masks
def prepare_bounding_boxes(self, images, annotations):
outputs = []
for image, annotation in zip(images, annotations):
image_copy = image.copy()
draw = ImageDraw.Draw(image_copy)
for _, _annotation in annotation.items():
bbox = _annotation["bbox"]
label = _annotation["label"]
draw.rectangle(bbox, outline="red", width=3)
draw.text((bbox[0], bbox[1] - 20), label, fill="red")
outputs.append(image_copy)
return outputs
def prepare_inputs(self, images, prompts):
prompts = prompts or ""
if isinstance(images, Image.Image):
images = [images]
if isinstance(prompts, str):
prompts = [prompts]
if len(images) != len(prompts):
raise ValueError("Number of images and annotation prompts must match.")
return images, prompts
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
images, annotation_task_prompt = self.prepare_inputs(
block_state.image, block_state.annotation_prompt
)
task = block_state.annotation_task
fill = block_state.fill
annotations = self.get_annotations(
components, images, annotation_task_prompt, task
)
block_state.annotations = annotations
if block_state.annotation_output_type == "mask_image":
block_state.mask_image = self.prepare_mask(images, annotations)
else:
block_state.mask_image = None
if block_state.annotation_output_type == "mask_overlay":
block_state.image = self.prepare_mask(images, annotations, overlay=True, fill=fill)
elif block_state.annotation_output_type == "bounding_box":
block_state.image = self.prepare_bounding_boxes(images, annotations)
self.set_block_state(state, block_state)
return components, state
```
Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines.
<hfoptions id="share">
<hfoption id="hf CLI">
```shell
# In the folder with the `block.py` file, run:
diffusers-cli custom_block
```
Then upload the block to the Hub:
```shell
hf upload <your repo id> . .
```
</hfoption>
<hfoption id="push_to_hub">
```py
from block import Florence2ImageAnnotatorBlock
block = Florence2ImageAnnotatorBlock()
block.push_to_hub("<your repo id>")
```
</hfoption>
</hfoptions>
## Using Custom Blocks
Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
# Fetch the Florence2 image annotator block that will create our mask
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
my_blocks = INPAINT_BLOCKS.copy()
# insert the annotation block before the image encoding step
my_blocks.insert("image_annotator", image_annotator_block, 1)
# Create our initial set of inpainting blocks
blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
pipe = blocks.init_pipeline(repo_id)
pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
image = image.resize((1024, 1024))
prompt = ["A red car"]
annotation_task = "<REFERRING_EXPRESSION_SEGMENTATION>"
annotation_prompt = ["the car"]
output = pipe(
prompt=prompt,
image=image,
annotation_task=annotation_task,
annotation_prompt=annotation_prompt,
annotation_output_type="mask_image",
num_inference_steps=35,
guidance_scale=7.5,
strength=0.95,
output="images"
)
output[0].save("florence-inpainting.png")
```
## Editing Custom Blocks
By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
# Fetch the Florence2 image annotator block that will create our mask
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
```
Any changes made to the block files in this folder will be reflected when you load the block again.
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# LoopSequentialPipelineBlocks
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `intermediate_inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
@@ -21,7 +21,6 @@ This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBl
[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
- `loop_intermediate_inputs` are intermediate variables from the [`~modular_pipelines.PipelineState`] and equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`].
- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
- `__call__` method defines the loop structure and iteration logic.
@@ -90,4 +89,4 @@ Add more loop blocks to run within each iteration with [`~modular_pipelines.Loop
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
```
```
@@ -37,17 +37,7 @@ A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermedi
]
```
- `intermediate_inputs` are values typically created from a previous block but it can also be directly provided if no preceding block generates them. Unlike `inputs`, `intermediate_inputs` can be modified.
Use `InputParam` to define `intermediate_inputs`.
```py
user_intermediate_inputs = [
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
]
```
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `intermediate_inputs` for subsequent blocks or available as the final output from running the pipeline.
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
Use `OutputParam` to define `intermediate_outputs`.
@@ -65,8 +55,8 @@ The intermediate inputs and outputs share data to connect blocks. They are acces
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs` and `intermediate_inputs`.
2. Implement the computation logic on the `inputs` and `intermediate_inputs`.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
2. Implement the computation logic on the `inputs`.
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
4. Return the components and state which becomes available to the next block.
@@ -76,7 +66,7 @@ def __call__(self, components, state):
block_state = self.get_block_state(state)
# Your computation logic here
# block_state contains all your inputs and intermediate_inputs
# block_state contains all your inputs
# Access them like: block_state.image, block_state.processed_image
# Update the pipeline state with your updated block_states
@@ -112,4 +102,4 @@ def __call__(self, components, state):
unet = components.unet
vae = components.vae
scheduler = components.scheduler
```
```
@@ -183,7 +183,7 @@ from diffusers.modular_pipelines import ComponentsManager
components = ComponentManager()
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
dd_pipeline.load_componenets(torch_dtype=torch.float16)
dd_pipeline.to("cuda")
```
@@ -12,11 +12,11 @@ specific language governing permissions and limitations under the License.
# SequentialPipelineBlocks
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `intermediate_inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `intermediate_inputs`.
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.
<hfoptions id="sequential">
<hfoption id="InputBlock">
@@ -110,4 +110,4 @@ Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by cal
```py
print(blocks)
print(blocks.doc)
```
```
@@ -139,12 +139,14 @@ Refer to the table below for a complete list of available attention backends and
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
| `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
@@ -0,0 +1,46 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoModel
The [`AutoModel`] class automatically detects and loads the correct model class (UNet, transformer, VAE) from a `config.json` file. You don't need to know the specific model class name ahead of time. It supports data types and device placement, and works across model types and libraries.
The example below loads a transformer from Diffusers and a text encoder from Transformers. Use the `subfolder` parameter to specify where to load the `config.json` file from.
```py
import torch
from diffusers import AutoModel, DiffusionPipeline
transformer = AutoModel.from_pretrained(
"Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda"
)
text_encoder = AutoModel.from_pretrained(
"Qwen/Qwen-Image", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda"
)
```
[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
```py
import torch
from diffusers import AutoModel
transformer = AutoModel.from_pretrained(
"custom/custom-transformer-model", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
)
```
If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
+8 -6
View File
@@ -1,8 +1,10 @@
- sections:
- local: index
title: 🧨 Diffusers
- local: quicktour
title: Tour rápido
- local: installation
title: Instalação
- local: index
title: Diffusers
- local: installation
title: Instalação
- local: quicktour
title: Tour rápido
- local: stable_diffusion
title: Desempenho básico
title: Primeiros passos
+2 -2
View File
@@ -18,11 +18,11 @@ specific language governing permissions and limitations under the License.
# Diffusers
🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou queira treinar seu próprio modelo de difusão, 🤗 Diffusers é uma modular caixa de ferramentas que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou quer treinar seu próprio modelo de difusão, 🤗 Diffusers é uma caixa de ferramentas modular que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
A Biblioteca tem três componentes principais:
- Pipelines de última geração para a geração em poucas linhas de código. Têm muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
- Pipelines de última geração para a geração em poucas linhas de código. muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
- Intercambiáveis [agendadores de ruído](api/schedulers/overview) para balancear as compensações entre velocidade e qualidade de geração.
- [Modelos](api/models) pré-treinados que podem ser usados como se fossem blocos de construção, e combinados com agendadores, para criar seu próprio sistema de difusão de ponta a ponta.
+3 -3
View File
@@ -21,7 +21,7 @@ specific language governing permissions and limitations under the License.
Recomenda-se instalar 🤗 Diffusers em um [ambiente virtual](https://docs.python.org/3/library/venv.html).
Se você não está familiarizado com ambiente virtuals, veja o [guia](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
Um ambiente virtual deixa mais fácil gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
Um ambiente virtual facilita gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
Comece criando um ambiente virtual no diretório do projeto:
@@ -100,12 +100,12 @@ pip install -e ".[flax]"
</jax>
</frameworkcontent>
Esses comandos irá linkar a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.
Esses comandos irão vincular a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.
Python então irá procurar dentro da pasta que você clonou além dos caminhos normais das bibliotecas.
Por exemplo, se o pacote python for tipicamente instalado no `~/anaconda3/envs/main/lib/python3.10/site-packages/`, o Python também irá procurar na pasta `~/diffusers/` que você clonou.
> [!WARNING]
> Você deve deixar a pasta `diffusers` se você quiser continuar usando a biblioteca.
> Você deve manter a pasta `diffusers` se quiser continuar usando a biblioteca.
Agora você pode facilmente atualizar seu clone para a última versão do 🤗 Diffusers com o seguinte comando:
+132
View File
@@ -0,0 +1,132 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
[[open-in-colab]]
# Desempenho básico
Difusão é um processo aleatório que demanda muito processamento. Você pode precisar executar o [`DiffusionPipeline`] várias vezes antes de obter o resultado desejado. Por isso é importante equilibrar cuidadosamente a velocidade de geração e o uso de memória para iterar mais rápido.
Este guia recomenda algumas dicas básicas de desempenho para usar o [`DiffusionPipeline`]. Consulte a seção de documentação sobre Otimização de Inferência, como [Acelerar inferência](./optimization/fp16) ou [Reduzir uso de memória](./optimization/memory) para guias de desempenho mais detalhados.
## Uso de memória
Reduzir a quantidade de memória usada indiretamente acelera a geração e pode ajudar um modelo a caber no dispositivo.
O método [`~DiffusionPipeline.enable_model_cpu_offload`] move um modelo para a CPU quando não está em uso para economizar memória da GPU.
```py
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
pipeline.enable_model_cpu_offload()
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
pipeline(prompt).images[0]
print(f"Memória máxima reservada: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```
## Velocidade de inferência
O processo de remoção de ruído é o mais exigente computacionalmente durante a difusão. Métodos que otimizam este processo aceleram a velocidade de inferência. Experimente os seguintes métodos para acelerar.
- Adicione `device_map="cuda"` para colocar o pipeline em uma GPU. Colocar um modelo em um acelerador, como uma GPU, aumenta a velocidade porque realiza computações em paralelo.
- Defina `torch_dtype=torch.bfloat16` para executar o pipeline em meia-precisão. Reduzir a precisão do tipo de dado aumenta a velocidade porque leva menos tempo para realizar computações em precisão mais baixa.
```py
import torch
import time
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
```
- Use um agendador mais rápido, como [`DPMSolverMultistepScheduler`], que requer apenas ~20-25 passos.
- Defina `num_inference_steps` para um valor menor. Reduzir o número de passos de inferência reduz o número total de computações. No entanto, isso pode resultar em menor qualidade de geração.
```py
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
start_time = time.perf_counter()
image = pipeline(prompt).images[0]
end_time = time.perf_counter()
print(f"Geração de imagem levou {end_time - start_time:.3f} segundos")
```
## Qualidade de geração
Muitos modelos de difusão modernos entregam imagens de alta qualidade imediatamente. No entanto, você ainda pode melhorar a qualidade de geração experimentando o seguinte.
- Experimente um prompt mais detalhado e descritivo. Inclua detalhes como o meio da imagem, assunto, estilo e estética. Um prompt negativo também pode ajudar, guiando um modelo para longe de características indesejáveis usando palavras como baixa qualidade ou desfocado.
```py
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
negative_prompt = "low quality, blurry, ugly, poor details"
pipeline(prompt, negative_prompt=negative_prompt).images[0]
```
Para mais detalhes sobre como criar prompts melhores, consulte a documentação sobre [Técnicas de prompt](./using-diffusers/weighted_prompts).
- Experimente um agendador diferente, como [`HeunDiscreteScheduler`] ou [`LMSDiscreteScheduler`], que sacrifica velocidade de geração por qualidade.
```py
import torch
from diffusers import DiffusionPipeline, HeunDiscreteScheduler
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
negative_prompt = "low quality, blurry, ugly, poor details"
pipeline(prompt, negative_prompt=negative_prompt).images[0]
```
## Próximos passos
Diffusers oferece otimizações mais avançadas e poderosas, como [group-offloading](./optimization/memory#group-offloading) e [compilação regional](./optimization/fp16#regional-compilation). Para saber mais sobre como maximizar o desempenho, consulte a seção sobre Otimização de Inferência.
+105 -2
View File
@@ -88,7 +88,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
| Flux Fill ControlNet Pipeline | A modified version of the `FluxFillPipeline` and `FluxControlNetInpaintPipeline` that supports Controlnet with Flux Fill model.| [Flux Fill ControlNet Pipeline](#Flux-Fill-ControlNet-Pipeline) | - | [pratim4dasude](https://github.com/pratim4dasude) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -5488,7 +5488,7 @@ Editing at Scale", many thanks to their contribution!
This implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated.
As explained in Section 3 of [the paper](https://arxiv.org/pdf/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
As explained in Section 3 of [the paper](https://huggingface.co/papers/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
## Example Usage
@@ -5527,3 +5527,106 @@ images = pipe(
).images
images[0].save("pizzeria.png")
```
# Flux Fill ControlNet Pipeline
This implementation of Flux Fill + ControlNet Inpaint combines the fill-style masked editing of FLUX.1-Fill-dev with full ControlNet conditioning. The base image is processed through the Fill model while the ControlNet receives the corresponding conditioning input (depth, canny, pose, etc.), and both outputs are fused during denoising to guide structure and composition.
While FLUX.1-Fill-dev is designed for mask-based edits, it was not originally trained to operate jointly with ControlNet. In practice, this combined setup works well for structured inpainting tasks, though results may vary depending on the conditioning strength and the alignment between the mask and the control input.
## Example Usage
```python
import torch
from diffusers import (
FluxControlNetModel,
FluxPriorReduxPipeline,
)
from diffusers.utils import load_image
# NEW PIPELINE (updated name)
from pipline_flux_fill_controlnet_Inpaint import FluxControlNetFillInpaintPipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
# Models
base_model = "black-forest-labs/FLUX.1-Fill-dev"
controlnet_model = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0"
prior_model = "black-forest-labs/FLUX.1-Redux-dev"
# Load ControlNet
controlnet = FluxControlNetModel.from_pretrained(
controlnet_model,
torch_dtype=dtype,
)
# Load Fill + ControlNet Pipeline
fill_pipe = FluxControlNetFillInpaintPipeline.from_pretrained(
base_model,
controlnet=controlnet,
torch_dtype=dtype,
).to(device)
# OPTIONAL FP8
# fill_pipe.transformer.enable_layerwise_casting(
# storage_dtype=torch.float8_e4m3fn,
# compute_dtype=torch.bfloat16
# )
# OPTIONAL Prior Redux
#pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
# prior_model,
# torch_dtype=dtype,
#).to(device)
# Inputs
# combined_image = load_image("person_input.png")
# 1. Prior conditioning
#prior_out = pipe_prior_redux(
# image=cloth_image,
# prompt=cloth_prompt,
#)
# 2. Fill Inpaint with ControlNet
# canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
img = load_image(r"imgs/background.jpg")
mask = load_image(r"imgs/mask.png")
control_image_depth = load_image(r"imgs/dog_depth _2.png")
result = fill_pipe(
prompt="a dog on a bench",
image=img,
mask_image=mask,
control_image=control_image_depth,
control_mode=[2], # union mode
control_guidance_start=0.0,
control_guidance_end=0.8,
controlnet_conditioning_scale=0.9,
height=1024,
width=1024,
strength=1.0,
guidance_scale=50.0,
num_inference_steps=60,
max_sequence_length=512,
# **prior_out,
)
# result.images[0].save("flux_fill_controlnet_inpaint.png")
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
result.images[0].save(f"flux_fill_controlnet_inpaint_depth{timestamp}.jpg")
```
+1 -1
View File
@@ -45,7 +45,7 @@ def check_size(image, height, width):
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
inner_image = inner_image.convert("RGBA")
image = image.convert("RGB")
+11 -6
View File
@@ -1966,16 +1966,21 @@ class MatryoshkaUNet2DConditionModel(
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -2294,10 +2299,10 @@ class MatryoshkaUNet2DConditionModel(
def _check_config(
self,
down_block_types: Tuple[str],
up_block_types: Tuple[str],
down_block_types: Tuple[str, ...],
up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
@@ -438,16 +438,21 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
File diff suppressed because it is too large Load Diff
@@ -490,7 +490,7 @@ class RegionalPromptingStableDiffusionPipeline(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -841,7 +841,7 @@ class RegionalPromptingStableDiffusionPipeline(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
@@ -872,7 +872,7 @@ class RegionalPromptingStableDiffusionPipeline(
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
@@ -1062,7 +1062,7 @@ class RegionalPromptingStableDiffusionPipeline(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
@@ -1668,7 +1668,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
+1 -2
View File
@@ -268,12 +268,11 @@ provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_f
**important**
> [!NOTE]
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source.
> To do this, execute the following steps in a new virtual environment:
> ```
> git clone https://github.com/huggingface/diffusers
> cd diffusers
> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
> pip install -e .
> ```
+315
View File
@@ -0,0 +1,315 @@
# DreamBooth training example for FLUX.2 [dev]
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux2-training)
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in:
```bash
hf auth login
```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_flux.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
## Memory Optimizations
> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption.
> However some techniques may be mutually exclusive so be sure to check before launching a training run.
### Remote Text Encoder
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
This way, the text encoder model is not loaded into memory during training.
> [!NOTE]
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
### Latent Caching
Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
### QLoRA: Low Precision Training with Quantization
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
- **FP8 training** with `torchao`:
enable FP8 training by passing `--do_fp8_training`.
> [!IMPORTANT] Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater.
> If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers like SimpleTuner, ai-toolkit, etc.
- **NF4 training** with `bitsandbytes`:
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing:
`--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
### Gradient Checkpointing and Accumulation
* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass.
by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs.
* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass.
### 8-bit-Adam Optimizer
When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
Make sure to install `bitsandbytes` if you want to do so.
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions.
### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux2"
accelerate launch train_dreambooth_lora_flux2.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--remote_text_encoder \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
to use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify -
```bash
--optimizer="prodigy"
```
> [!TIP]
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
To perform DreamBooth with LoRA, run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux2-lora"
accelerate launch train_dreambooth_lora_flux2.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--remote_text_encoder \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1. \
--report_to="wandb" \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
- lora_alpha vs. rank:
This ratio dictates the LoRA's effective strength:
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
the exact modules for LoRA training. Here are some examples of target modules you can provide:
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
> [!NOTE]
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
## Training Image-to-Image
Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image(I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
**important**
**Important**
To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
To start, you must have a dataset containing triplets:
* Condition image - the input image to be transformed.
* Target image - the desired output image after transformation.
* Instruction - a text prompt describing the transformation from the condition image to the target image.
[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
```bash
accelerate launch train_dreambooth_lora_flux2_img2img.py \
--pretrained_model_name_or_path=black-forest-labs/FLUX.2-dev \
--output_dir="flux2-i2i" \
--dataset_name="kontext-community/relighting" \
--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
--do_fp8_training \
--gradient_checkpointing \
--remote_text_encoder \
--cache_latents \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--optimizer="adamw" \
--use_8bit_adam \
--cache_latents \
--learning_rate=1e-4 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=200 \
--max_train_steps=1000 \
--rank=16\
--seed="0"
```
More generally, when performing I2I fine-tuning, we expect you to:
* Have a dataset `kontext-community/relighting`
* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
### Misc notes
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
### Aspect Ratio Bucketing
we've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
`
Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
@@ -0,0 +1,262 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAFlux2(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "dog"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2"
script_path = "examples/dreambooth/train_dreambooth_lora_flux2.py"
transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"
def test_dreambooth_lora_flux2(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
starts_with_transformer = all(
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--max_sequence_length 8
--checkpointing_steps=2
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--max_sequence_length 8
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 8
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways:
- you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
Below, we explain both in more detail.
#### Provide the dataset as a folder
@@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
return res.expand(broadcast_shape)
def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
"""
Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
"""
if tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
channels = tensor.shape[0]
if channels == 3:
return tensor
if channels == 1:
return tensor.repeat(3, 1, 1)
if channels == 2:
return torch.cat([tensor, tensor[:1]], dim=0)
if channels > 3:
return tensor[:3]
raise ValueError(f"Unsupported number of channels: {channels}")
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -260,6 +278,11 @@ def parse_args():
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--preserve_input_precision",
action="store_true",
help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -453,19 +476,41 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation.
spatial_augmentations = [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
]
augmentations = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
spatial_augmentations
+ [
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
precision_augmentations = transforms.Compose(
[
transforms.PILToTensor(),
transforms.Lambda(_ensure_three_channels),
transforms.ConvertImageDtype(torch.float32),
]
+ spatial_augmentations
+ [transforms.Normalize([0.5], [0.5])]
)
def transform_images(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
processed = []
for image in examples["image"]:
if not args.preserve_input_precision:
processed.append(augmentations(image.convert("RGB")))
else:
precise_image = image
if precise_image.mode == "P":
precise_image = precise_image.convert("RGB")
processed.append(precision_augmentations(precise_image))
return {"input": processed}
logger.info(f"Dataset size: {len(dataset)}")
+475
View File
@@ -0,0 +1,475 @@
import argparse
from contextlib import nullcontext
from typing import Any, Dict, Tuple
import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, GenerationConfig, Mistral3ForConditionalGeneration
from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
from diffusers.utils.import_utils import is_accelerate_available
"""
# VAE
python scripts/convert_flux2_to_diffusers.py \
--original_state_dict_repo_id "diffusers-internal-dev/new-model-image" \
--vae_filename "flux2-vae.sft" \
--output_path "/raid/yiyi/dummy-flux2-diffusers" \
--vae
# DiT
python scripts/convert_flux2_to_diffusers.py \
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
--dit_filename flux-dev-dummy.sft \
--dit \
--output_path .
# Full pipe
python scripts/convert_flux2_to_diffusers.py \
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
--dit_filename flux-dev-dummy.sft \
--vae_filename "flux2-vae.sft" \
--dit --vae --full_pipe \
--output_path .
"""
CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--dit", action="store_true")
parser.add_argument("--vae_dtype", type=str, default="fp32")
parser.add_argument("--dit_dtype", type=str, default="bf16")
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--full_pipe", action="store_true")
parser.add_argument("--output_path", type=str)
args = parser.parse_args()
def load_original_checkpoint(args, filename):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
DIFFUSERS_VAE_TO_FLUX2_MAPPING = {
"encoder.conv_in.weight": "encoder.conv_in.weight",
"encoder.conv_in.bias": "encoder.conv_in.bias",
"encoder.conv_out.weight": "encoder.conv_out.weight",
"encoder.conv_out.bias": "encoder.conv_out.bias",
"encoder.conv_norm_out.weight": "encoder.norm_out.weight",
"encoder.conv_norm_out.bias": "encoder.norm_out.bias",
"decoder.conv_in.weight": "decoder.conv_in.weight",
"decoder.conv_in.bias": "decoder.conv_in.bias",
"decoder.conv_out.weight": "decoder.conv_out.weight",
"decoder.conv_out.bias": "decoder.conv_out.bias",
"decoder.conv_norm_out.weight": "decoder.norm_out.weight",
"decoder.conv_norm_out.bias": "decoder.norm_out.bias",
"quant_conv.weight": "encoder.quant_conv.weight",
"quant_conv.bias": "encoder.quant_conv.bias",
"post_quant_conv.weight": "decoder.post_quant_conv.weight",
"post_quant_conv.bias": "decoder.post_quant_conv.bias",
"bn.running_mean": "bn.running_mean",
"bn.running_var": "bn.running_var",
}
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = (
ldm_key.replace(mapping["old"], mapping["new"])
.replace("norm.weight", "group_norm.weight")
.replace("norm.bias", "group_norm.bias")
.replace("q.weight", "to_q.weight")
.replace("q.bias", "to_q.bias")
.replace("k.weight", "to_k.weight")
.replace("k.bias", "to_k.bias")
.replace("v.weight", "to_v.weight")
.replace("v.bias", "to_v.bias")
.replace("proj_out.weight", "to_out.0.weight")
.replace("proj_out.bias", "to_out.0.bias")
)
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
# proj_attn.weight has to be converted from conv 1D to linear
shape = new_checkpoint[diffusers_key].shape
if len(shape) == 3:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
elif len(shape) == 4:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
def convert_flux2_vae_checkpoint_to_diffusers(vae_state_dict, config):
new_checkpoint = {}
for diffusers_key, ldm_key in DIFFUSERS_VAE_TO_FLUX2_MAPPING.items():
if ldm_key not in vae_state_dict:
continue
new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len(config["down_block_types"])
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
)
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
f"encoder.down.{i}.downsample.conv.bias"
)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
update_vae_attentions_ldm_to_diffusers(
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
)
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len(config["up_block_types"])
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
)
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
update_vae_attentions_ldm_to_diffusers(
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
# Image and text input projections
"img_in": "x_embedder",
"txt_in": "context_embedder",
# Timestep and guidance embeddings
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
# Modulation parameters
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
"single_stream_modulation.lin": "single_stream_modulation.linear",
# Final output layer
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
"final_layer.linear": "proj_out",
}
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
"final_layer.adaLN_modulation.1": "norm_out.linear",
}
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
# Handle fused QKV projections separately as we need to break into Q, K, V projections
"img_attn.norm.query_norm": "attn.norm_q",
"img_attn.norm.key_norm": "attn.norm_k",
"img_attn.proj": "attn.to_out.0",
"img_mlp.0": "ff.linear_in",
"img_mlp.2": "ff.linear_out",
"txt_attn.norm.query_norm": "attn.norm_added_q",
"txt_attn.norm.key_norm": "attn.norm_added_k",
"txt_attn.proj": "attn.to_add_out",
"txt_mlp.0": "ff_context.linear_in",
"txt_mlp.2": "ff_context.linear_out",
}
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
"linear1": "attn.to_qkv_mlp_proj",
"linear2": "attn.to_out",
"norm.query_norm": "attn.norm_q",
"norm.key_norm": "attn.norm_k",
}
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use
# diffusers implementation
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight
if ".weight" not in key:
return
# If adaLN_modulation is in the key, swap scale and shift parameters
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
if "adaLN_modulation" in key:
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
# Assume all such keys are in the AdaLayerNorm key map
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
new_key = ".".join([new_key_without_param_type, param_type])
swapped_weight = swap_scale_shift(state_dict.pop(key))
state_dict[new_key] = swapped_weight
return
def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight, bias, or scale
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
return
new_prefix = "transformer_blocks"
if "double_blocks." in key:
parts = key.split(".")
block_idx = parts[1]
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
within_block_name = ".".join(parts[2:-1])
param_type = parts[-1]
if param_type == "scale":
param_type = "weight"
if "qkv" in within_block_name:
fused_qkv_weight = state_dict.pop(key)
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
if "img" in modality_block_name:
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = "attn.to_q"
new_k_name = "attn.to_k"
new_v_name = "attn.to_v"
elif "txt" in modality_block_name:
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = "attn.add_q_proj"
new_k_name = "attn.add_k_proj"
new_v_name = "attn.add_v_proj"
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
state_dict[new_q_key] = to_q_weight
state_dict[new_k_key] = to_k_weight
state_dict[new_v_key] = to_v_weight
else:
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight, bias, or scale
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
return
# Mapping:
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
new_prefix = "single_transformer_blocks"
if "single_blocks." in key:
parts = key.split(".")
block_idx = parts[1]
within_block_name = ".".join(parts[2:-1])
param_type = parts[-1]
if param_type == "scale":
param_type = "weight"
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
param = state_dict.pop(key)
state_dict[new_key] = param
return
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"adaLN_modulation": convert_ada_layer_norm_weights,
"double_blocks": convert_flux2_double_stream_blocks,
"single_blocks": convert_flux2_single_stream_blocks,
}
def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
if model_type == "test" or model_type == "dummy-flux2":
config = {
"model_id": "diffusers-internal-dev/dummy-flux2",
"diffusers_config": {
"patch_size": 1,
"in_channels": 128,
"num_layers": 8,
"num_single_layers": 48,
"attention_head_dim": 128,
"num_attention_heads": 48,
"joint_attention_dim": 15360,
"timestep_guidance_channels": 256,
"mlp_ratio": 3.0,
"axes_dims_rope": (32, 32, 32, 32),
"rope_theta": 2000,
"eps": 1e-6,
},
}
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str):
config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
transformer = Flux2Transformer2DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def main(args):
if args.vae:
original_vae_ckpt = load_original_checkpoint(args, filename=args.vae_filename)
vae = AutoencoderKLFlux2()
converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_vae_ckpt, vae.config)
vae.load_state_dict(converted_vae_state_dict, strict=True)
if not args.full_pipe:
vae_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
vae.to(vae_dtype).save_pretrained(f"{args.output_path}/vae")
if args.dit:
original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
if not args.full_pipe:
dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
if args.full_pipe:
tokenizer_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
text_encoder_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
generate_config = GenerationConfig.from_pretrained(text_encoder_id)
generate_config.do_sample = True
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
text_encoder_id, generation_config=generate_config, torch_dtype=torch.bfloat16
)
tokenizer = AutoProcessor.from_pretrained(tokenizer_id)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
"black-forest-labs/FLUX.1-dev", subfolder="scheduler"
)
pipe = Flux2Pipeline(
vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)
pipe.save_pretrained(args.output_path)
if __name__ == "__main__":
main(args)
@@ -10,7 +10,7 @@ from accelerate import init_empty_weights
from diffusers import (
SanaControlNetModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
+1 -1
View File
@@ -20,7 +20,7 @@ from diffusers import (
SanaTransformer2DModel,
SCMScheduler,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
+327
View File
@@ -0,0 +1,327 @@
#!/usr/bin/env python
from __future__ import annotations
import argparse
import os
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
AutoencoderKLWan,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaVideoPipeline,
SanaVideoTransformer3DModel,
UniPCMultistepScheduler,
)
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
def main(args):
cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
snapshot_download(
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
cache_dir=cache_dir_path,
repo_type="model",
)
file_path = hf_hub_download(
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
cache_dir=cache_dir_path,
repo_type="model",
)
else:
file_path = args.orig_ckpt_path
print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
all_state_dict = torch.load(file_path, weights_only=True)
state_dict = all_state_dict.pop("state_dict")
converted_state_dict = {}
# Patch embeddings.
converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight")
converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias")
# Caption projection.
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Shared norm.
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
# y norm
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# scheduler
flow_shift = 8.0
if args.task == "i2v":
assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
# model config
layer_num = 20
# Positional embedding interpolation scale.
qk_norm = True
# sample size
if args.video_size == 480:
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
patch_size = (1, 2, 2)
elif args.video_size == 720:
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
patch_size = (1, 1, 1)
else:
raise ValueError(f"Video size {args.video_size} is not supported.")
for depth in range(layer_num):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
f"blocks.{depth}.scale_shift_table"
)
# Linear Attention is all you need 🤘
# Self attention.
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
if qk_norm is not None:
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.attn.k_norm.weight"
)
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.attn.proj.bias"
)
# Feed-forward.
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.inverted_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.inverted_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.depth_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.depth_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.point_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_temp.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.t_conv.weight"
)
# Cross-attention.
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
if qk_norm is not None:
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.k_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.bias"
)
# Final block.
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
# Transformer
with CTX():
transformer_kwargs = {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 20,
"attention_head_dim": 112,
"num_layers": 20,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"caption_channels": 2304,
"mlp_ratio": 3.0,
"attention_bias": False,
"sample_size": sample_size,
"patch_size": patch_size,
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 1024,
}
transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
try:
state_dict.pop("y_embedder.y_embedding")
state_dict.pop("pos_embed")
state_dict.pop("logvar_linear.weight")
state_dict.pop("logvar_linear.bias")
except KeyError:
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")
transformer = transformer.to(weight_dtype)
if not args.save_full_pipeline:
print(
colored(
f"Only saving transformer model of {args.model_type}. "
f"Set --save_full_pipeline to save the whole Pipeline",
"green",
attrs=["bold"],
)
)
transformer.save_pretrained(
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
)
else:
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
# Text Encoder
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
tokenizer.padding_side = "right"
text_encoder = AutoModelForCausalLM.from_pretrained(
text_encoder_model_path, torch_dtype=torch.bfloat16
).get_decoder()
# Choose the appropriate pipeline and scheduler based on model type
# Original Sana scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
elif args.scheduler_type == "uni-pc":
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction",
use_flow_sigmas=True,
num_train_timesteps=1000,
flow_shift=flow_shift,
)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
pipe = SanaVideoPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=vae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--video_size",
default=480,
type=int,
choices=[480, 720],
required=False,
help="Video size of pretrained model, 480 or 720.",
)
parser.add_argument(
"--model_type",
default="SanaVideo",
type=str,
choices=[
"SanaVideo",
],
)
parser.add_argument(
"--scheduler_type",
default="flow-dpm_solver",
type=str,
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = DTYPE_MAPPING[args.dtype]
main(args)
+1 -1
View File
@@ -7,7 +7,7 @@ from accelerate import init_empty_weights
from diffusers import AutoencoderKL, SD3Transformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
+1 -1
View File
@@ -18,7 +18,7 @@ from diffusers import (
StableAudioPipeline,
StableAudioProjectionModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils import is_accelerate_available
+1 -1
View File
@@ -20,7 +20,7 @@ from diffusers import (
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
+1 -1
View File
@@ -20,7 +20,7 @@ from diffusers import (
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
+265 -6
View File
@@ -6,11 +6,20 @@ import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
from transformers import (
AutoProcessor,
AutoTokenizer,
CLIPImageProcessor,
CLIPVisionModel,
CLIPVisionModelWithProjection,
UMT5EncoderModel,
)
from diffusers import (
AutoencoderKLWan,
UniPCMultistepScheduler,
WanAnimatePipeline,
WanAnimateTransformer3DModel,
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
@@ -105,8 +114,203 @@ VACE_TRANSFORMER_KEYS_RENAME_DICT = {
"after_proj": "proj_out",
}
ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"cross_attn.k_img": "attn2.to_k_img",
"cross_attn.v_img": "attn2.to_v_img",
"cross_attn.norm_k_img": "attn2.norm_k_img",
# After cross_attn -> attn2 rename, we need to rename the img keys
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
# Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
# Motion encoder mappings
# The name mapping is complicated for the convolutional part so we handle that in its own function
"motion_encoder.enc.fc": "motion_encoder.motion_network",
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
# Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
"face_encoder.conv1_local.conv": "face_encoder.conv1_local",
"face_encoder.conv2.conv": "face_encoder.conv2",
"face_encoder.conv3.conv": "face_encoder.conv3",
# Face adapter mappings are handled in a separate function
}
# TODO: Verify this and simplify if possible.
def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
"""
Convert all motion encoder weights for Animate model.
In the original model:
- All Linear layers in fc use EqualLinear
- All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
- Blur kernels are stored as buffers in Sequential modules
- ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]
Conversion strategy:
1. Drop .kernel buffers (blur kernels)
2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
"""
# Skip if not a weight, bias, or kernel
if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
return
# Handle Blur kernel buffers from original implementation.
# After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
# Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
if ".kernel" in key and "motion_encoder" in key:
# Remove unexpected blur kernel buffers to avoid strict load errors
state_dict.pop(key, None)
return
# Rename Sequential indices to named components in ConvLayer and ResBlock
if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
parts = key.split(".")
# Find the sequential index (digit) after convs or after conv1/conv2/skip
# Examples:
# - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
# - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
# - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
# - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
# - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
# - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
# - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
# - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
# - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
# - enc.net_app.convs.8 -> conv_out (final conv layer)
convs_idx = parts.index("convs") if "convs" in parts else -1
if convs_idx >= 0 and len(parts) - convs_idx >= 2:
bias = False
# The nn.Sequential index will always follow convs
sequential_idx = int(parts[convs_idx + 1])
if sequential_idx == 0:
if key.endswith(".weight"):
new_key = "motion_encoder.conv_in.weight"
elif key.endswith(".bias"):
new_key = "motion_encoder.conv_in.act_fn.bias"
bias = True
elif sequential_idx == final_conv_idx:
if key.endswith(".weight"):
new_key = "motion_encoder.conv_out.weight"
else:
# Intermediate .convs. layers, which get mapped to .res_blocks.
prefix = "motion_encoder.res_blocks."
layer_name = parts[convs_idx + 2]
if layer_name == "skip":
layer_name = "conv_skip"
if key.endswith(".weight"):
param_name = "weight"
elif key.endswith(".bias"):
param_name = "act_fn.bias"
bias = True
suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
suffix = ".".join(suffix_parts)
new_key = prefix + suffix
param = state_dict.pop(key)
if bias:
param = param.squeeze()
state_dict[new_key] = param
return
return
return
def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
"""
Convert face adapter weights for the Animate model.
The original model uses a fused KV projection but the diffusers models uses separate K and V projections.
"""
# Skip if not a weight or bias
if ".weight" not in key and ".bias" not in key:
return
prefix = "face_adapter."
if ".fuser_blocks." in key:
parts = key.split(".")
module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
block_idx = parts[module_list_idx + 1]
layer_name = parts[module_list_idx + 2]
param_name = parts[module_list_idx + 3]
if layer_name == "linear1_kv":
layer_name_k = "to_k"
layer_name_v = "to_v"
suffix_k = ".".join([block_idx, layer_name_k, param_name])
suffix_v = ".".join([block_idx, layer_name_v, param_name])
new_key_k = prefix + suffix_k
new_key_v = prefix + suffix_v
kv_proj = state_dict.pop(key)
k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
state_dict[new_key_k] = k_proj
state_dict[new_key_v] = v_proj
return
else:
if layer_name == "q_norm":
new_layer_name = "norm_q"
elif layer_name == "k_norm":
new_layer_name = "norm_k"
elif layer_name == "linear1_q":
new_layer_name = "to_q"
elif layer_name == "linear2":
new_layer_name = "to_out"
suffix_parts = [block_idx, new_layer_name, param_name]
suffix = ".".join(suffix_parts)
new_key = prefix + suffix
state_dict[new_key] = state_dict.pop(key)
return
return
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"motion_encoder": convert_animate_motion_encoder_weights,
"face_adapter": convert_animate_face_adapter_weights,
}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -364,6 +568,37 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-Animate-14B":
config = {
"model_id": "Wan-AI/Wan2.2-Animate-14B",
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": (1, 2, 2),
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"rope_max_seq_len": 1024,
"pos_embed_seq_len": None,
"motion_encoder_size": 512, # Start of Wan Animate-specific configs
"motion_style_dim": 512,
"motion_dim": 20,
"motion_encoder_dim": 512,
"face_encoder_hidden_dim": 1024,
"face_encoder_num_heads": 4,
"inject_face_latents_blocks": 5,
},
}
RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
@@ -380,10 +615,12 @@ def convert_transformer(model_type: str, stage: str = None):
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
if "VACE" not in model_type:
transformer = WanTransformer3DModel.from_config(diffusers_config)
else:
if "Animate" in model_type:
transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
elif "VACE" in model_type:
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
else:
transformer = WanTransformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -397,7 +634,12 @@ def convert_transformer(model_type: str, stage: str = None):
continue
handler_fn_inplace(key, original_state_dict)
# Load state dict into the meta model, which will materialize the tensors
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
# Move to CPU to ensure all tensors are materialized
transformer = transformer.to("cpu")
return transformer
@@ -926,7 +1168,7 @@ DTYPE_MAPPING = {
if __name__ == "__main__":
args = get_args()
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
transformer = convert_transformer(args.model_type, stage="high_noise_model")
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
else:
@@ -942,7 +1184,7 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
if "FLF2V" in args.model_type:
flow_shift = 16.0
elif "TI2V" in args.model_type:
elif "TI2V" in args.model_type or "Animate" in args.model_type:
flow_shift = 5.0
else:
flow_shift = 3.0
@@ -954,6 +1196,8 @@ if __name__ == "__main__":
if args.dtype != "none":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)
if transformer_2 is not None:
transformer_2.to(dtype)
if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
pipe = WanImageToVideoPipeline(
@@ -1016,6 +1260,21 @@ if __name__ == "__main__":
vae=vae,
scheduler=scheduler,
)
elif "Animate" in args.model_type:
image_encoder = CLIPVisionModel.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
pipe = WanAnimatePipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
image_encoder=image_encoder,
image_processor=image_processor,
)
else:
pipe = WanPipeline(
transformer=transformer,
+2 -2
View File
@@ -122,7 +122,7 @@ _deps = [
"pytest",
"pytest-timeout",
"pytest-xdist",
"python>=3.9.0",
"python>=3.8.0",
"ruff==0.9.10",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
@@ -287,7 +287,7 @@ setup(
packages=find_packages("src"),
package_data={"diffusers": ["py.typed"]},
include_package_data=True,
python_requires=">=3.10.0",
python_requires=">=3.8.0",
install_requires=list(install_requires),
extras_require=extras,
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
+26
View File
@@ -186,6 +186,7 @@ else:
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLCosmos",
"AutoencoderKLFlux2",
"AutoencoderKLHunyuanImage",
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
@@ -202,6 +203,7 @@ else:
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
"ChronoEditTransformer3DModel",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"CogView4Transformer2DModel",
@@ -214,6 +216,7 @@ else:
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
"Flux2Transformer2DModel",
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
@@ -246,6 +249,7 @@ else:
"QwenImageTransformer2DModel",
"SanaControlNetModel",
"SanaTransformer2DModel",
"SanaVideoTransformer3DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
@@ -266,8 +270,10 @@ else:
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
"WanAnimateTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"ZImageTransformer2DModel",
"attention_backend",
]
)
@@ -405,6 +411,7 @@ else:
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanModularPipeline",
]
@@ -435,6 +442,7 @@ else:
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
"ChronoEditPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
"CogVideoXImageToVideoPipeline",
@@ -452,6 +460,7 @@ else:
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
"Flux2Pipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline",
@@ -540,10 +549,13 @@ else:
"QwenImagePipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
"SanaImageToVideoPipeline",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintImg2ImgPipeline",
"SanaSprintPipeline",
"SanaVideoPipeline",
"SanaVideoPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -631,6 +643,7 @@ else:
"VisualClozeGenerationPipeline",
"VisualClozePipeline",
"VQDiffusionPipeline",
"WanAnimatePipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVACEPipeline",
@@ -638,6 +651,7 @@ else:
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
"ZImagePipeline",
]
)
@@ -891,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLFlux2,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
@@ -907,6 +922,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
ChronoEditTransformer3DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
@@ -919,6 +935,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CosmosTransformer3DModel,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
@@ -951,6 +968,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageTransformer2DModel,
SanaControlNetModel,
SanaTransformer2DModel,
SanaVideoTransformer3DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
@@ -970,6 +988,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
attention_backend,
@@ -1084,6 +1103,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
Wan22AutoBlocks,
WanAutoBlocks,
WanModularPipeline,
)
@@ -1110,6 +1130,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
ChronoEditPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
@@ -1127,6 +1148,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
Flux2Pipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
@@ -1215,10 +1237,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImagePipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
SanaImageToVideoPipeline,
SanaPAGPipeline,
SanaPipeline,
SanaSprintImg2ImgPipeline,
SanaSprintPipeline,
SanaVideoPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
@@ -1305,6 +1329,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VisualClozeGenerationPipeline,
VisualClozePipeline,
VQDiffusionPipeline,
WanAnimatePipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVACEPipeline,
@@ -1312,6 +1337,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
ZImagePipeline,
)
try:
+12 -12
View File
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict, List
from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME
@@ -33,13 +33,13 @@ class PipelineCallback(ConfigMixin):
raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
@property
def tensor_inputs(self) -> list[str]:
def tensor_inputs(self) -> List[str]:
raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")
def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
@@ -49,14 +49,14 @@ class MultiPipelineCallbacks:
provides a unified interface for calling all of them.
"""
def __init__(self, callbacks: list[PipelineCallback]):
def __init__(self, callbacks: List[PipelineCallback]):
self.callbacks = callbacks
@property
def tensor_inputs(self) -> list[str]:
def tensor_inputs(self) -> List[str]:
return [input for callback in self.callbacks for input in callback.tensor_inputs]
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
"""
Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
"""
@@ -76,7 +76,7 @@ class SDCFGCutoffCallback(PipelineCallback):
tensor_inputs = ["prompt_embeds"]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
@@ -109,7 +109,7 @@ class SDXLCFGCutoffCallback(PipelineCallback):
"add_time_ids",
]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
@@ -152,7 +152,7 @@ class SDXLControlnetCFGCutoffCallback(PipelineCallback):
"image",
]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
@@ -195,7 +195,7 @@ class IPAdapterScaleCutoffCallback(PipelineCallback):
tensor_inputs = []
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
@@ -219,7 +219,7 @@ class SD3CFGCutoffCallback(PipelineCallback):
tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
+20 -18
View File
@@ -24,7 +24,7 @@ import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
@@ -94,10 +94,10 @@ class ConfigMixin:
Class attributes:
- **config_name** (`str`) -- A filename under which the config should stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`list[str]`) -- A list of attributes that should not be saved in the config (should be
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by subclass).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
- **_deprecated_kwargs** (`list[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
subclass).
"""
@@ -143,7 +143,7 @@ class ConfigMixin:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def save_config(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
[`~ConfigMixin.from_config`] class method.
@@ -155,7 +155,7 @@ class ConfigMixin:
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`dict[str, Any]`, *optional*):
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if os.path.isfile(save_directory):
@@ -189,13 +189,13 @@ class ConfigMixin:
@classmethod
def from_config(
cls, config: FrozenDict | dict[str, Any] = None, return_unused_kwargs=False, **kwargs
) -> Self | tuple[Self, dict[str, Any]]:
cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
) -> Union[Self, Tuple[Self, Dict[str, Any]]]:
r"""
Instantiate a Python class from a config dictionary.
Parameters:
config (`dict[str, Any]`):
config (`Dict[str, Any]`):
A config dictionary from which the Python class is instantiated. Make sure to only load configuration
files of compatible classes.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
@@ -292,11 +292,11 @@ class ConfigMixin:
@validate_hf_hub_args
def load_config(
cls,
pretrained_model_name_or_path: str | os.PathLike,
pretrained_model_name_or_path: Union[str, os.PathLike],
return_unused_kwargs=False,
return_commit_hash=False,
**kwargs,
) -> tuple[dict[str, Any], dict[str, Any]]:
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
r"""
Load a model or scheduler configuration.
@@ -315,7 +315,7 @@ class ConfigMixin:
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
@@ -352,7 +352,7 @@ class ConfigMixin:
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
dduf_entries: Optional[dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
@@ -563,7 +563,9 @@ class ConfigMixin:
return init_dict, unused_kwargs, hidden_config_dict
@classmethod
def _dict_from_json_file(cls, json_file: str | os.PathLike, dduf_entries: Optional[dict[str, DDUFEntry]] = None):
def _dict_from_json_file(
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
if dduf_entries:
text = dduf_entries[json_file].read_text()
else:
@@ -575,12 +577,12 @@ class ConfigMixin:
return f"{self.__class__.__name__} {self.to_json_string()}"
@property
def config(self) -> dict[str, Any]:
def config(self) -> Dict[str, Any]:
"""
Returns the config of the class as a frozen dictionary
Returns:
`dict[str, Any]`: Config of the class.
`Dict[str, Any]`: Config of the class.
"""
return self._internal_dict
@@ -623,7 +625,7 @@ class ConfigMixin:
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: str | os.PathLike):
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save the configuration instance's parameters to a JSON file.
@@ -635,7 +637,7 @@ class ConfigMixin:
writer.write(self.to_json_string())
@classmethod
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: dict[str, DDUFEntry]):
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
# paths inside a DDUF file must always be "/"
config_file = (
cls.config_name
@@ -754,7 +756,7 @@ class LegacyConfigMixin(ConfigMixin):
"""
@classmethod
def from_config(cls, config: FrozenDict | dict[str, Any] = None, return_unused_kwargs=False, **kwargs):
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
# To prevent dependency import problem.
from .models.model_loading_utils import _fetch_remapped_cls_from_config
+1 -1
View File
@@ -29,7 +29,7 @@ deps = {
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.9.0",
"python": "python>=3.8.0",
"ruff": "ruff==0.9.10",
"safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -79,7 +77,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
@@ -90,6 +88,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -99,6 +99,19 @@ class AdaptiveProjectedMixGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
+16 -8
View File
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
@@ -38,10 +36,10 @@ class AutoGuidance(BaseGuidance):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
auto_guidance_layers (`int` or `list[int]`, *optional*):
auto_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided.
auto_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
dropout (`float`, *optional*):
@@ -67,8 +65,8 @@ class AutoGuidance(BaseGuidance):
def __init__(
self,
guidance_scale: float = 7.5,
auto_guidance_layers: Optional[int | list[int]] = None,
auto_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
dropout: Optional[float] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
@@ -135,7 +133,7 @@ class AutoGuidance(BaseGuidance):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
@@ -143,6 +141,16 @@ class AutoGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -93,7 +91,7 @@ class ClassifierFreeGuidance(BaseGuidance):
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
@@ -101,6 +99,16 @@ class ClassifierFreeGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -79,7 +77,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
@@ -87,6 +85,16 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -39,7 +37,7 @@ else:
build_laplacian_pyramid_func = None
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
(Algorithm 2).
@@ -60,7 +58,7 @@ def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -
return v0_parallel, v0_orthogonal
def build_image_from_pyramid(pyramid: list[torch.Tensor]) -> torch.Tensor:
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
"""
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
(Algorithm 2).
@@ -101,19 +99,19 @@ class FrequencyDecoupledGuidance(BaseGuidance):
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scales (`list[float]`, defaults to `[10.0, 5.0]`):
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
descending order).
guidance_rescale (`float` or `list[float]`, defaults to `0.0`):
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
`guidance_scales`.
parallel_weights (`float` or `list[float]`, *optional*):
parallel_weights (`float` or `List[float]`, *optional*):
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
@@ -122,10 +120,10 @@ class FrequencyDecoupledGuidance(BaseGuidance):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float` or `list[float]`, defaults to `0.0`):
start (`float` or `List[float]`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
should be the same length as `guidance_scales`.
stop (`float` or `list[float]`, defaults to `1.0`):
stop (`float` or `List[float]`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
should be the same length as `guidance_scales`.
guidance_rescale_space (`str`, defaults to `"data"`):
@@ -143,12 +141,12 @@ class FrequencyDecoupledGuidance(BaseGuidance):
@register_to_config
def __init__(
self,
guidance_scales: list[float] | tuple[float] = [10.0, 5.0],
guidance_rescale: float | list[float] | tuple[float] = 0.0,
parallel_weights: Optional[float | list[float] | tuple[float]] = None,
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
use_original_formulation: bool = False,
start: float | list[float] | tuple[float] = 0.0,
stop: float | list[float] | tuple[float] = 1.0,
start: Union[float, List[float], Tuple[float]] = 0.0,
stop: Union[float, List[float], Tuple[float]] = 1.0,
guidance_rescale_space: str = "data",
upcast_to_double: bool = True,
enabled: bool = True,
@@ -220,7 +218,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
f"({len(self.guidance_scales)})"
)
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
@@ -228,6 +226,16 @@ class FrequencyDecoupledGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
+63 -15
View File
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
@@ -53,8 +51,8 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._num_inference_steps: int = None
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: dict[str, str | tuple[str, str]] = None
self._enabled = True
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
self._enabled = enabled
if not (0.0 <= start < 1.0):
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
@@ -103,7 +101,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep = timestep
self._count_prepared = 0
def get_state(self) -> dict[str, Any]:
def get_state(self) -> Dict[str, Any]:
"""
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
the __repr__ method. Returns:
@@ -165,10 +163,15 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"""
pass
def prepare_inputs(self, data: "BlockState") -> list["BlockState"]:
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
def __call__(self, data: list["BlockState"]) -> Any:
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
def __call__(self, data: List["BlockState"]) -> Any:
if not all(hasattr(d, "noise_pred") for d in data):
raise ValueError("Expected all data to have `noise_pred` attribute.")
if len(data) != self.num_conditions:
@@ -196,7 +199,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
@classmethod
def _prepare_batch(
cls,
data: dict[str, tuple[torch.Tensor, torch.Tensor]],
data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
tuple_index: int,
identifier: str,
) -> "BlockState":
@@ -205,7 +208,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
Args:
input_fields (`dict[str, Union[str, tuple[str, str]]]`):
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation. If a string is provided, it will be used as the
@@ -236,11 +239,56 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
@classmethod
def _prepare_batch_from_block_state(
cls,
input_fields: Dict[str, Union[str, Tuple[str, str]]],
data: "BlockState",
tuple_index: int,
identifier: str,
) -> "BlockState":
"""
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
Args:
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation. If a string is provided, it will be used as the
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
length 2 is provided, the first element must be the conditional data identifier and the second element
must be the unconditional data identifier or None.
data (`BlockState`):
The input data to be prepared.
tuple_index (`int`):
The index to use when accessing input fields that are tuples.
Returns:
`BlockState`: The prepared batch of data.
"""
from ..modular_pipelines.modular_pipeline import BlockState
data_batch = {}
for key, value in input_fields.items():
try:
if isinstance(value, str):
data_batch[key] = getattr(data, value)
elif isinstance(value, tuple):
data_batch[key] = getattr(data, value[tuple_index])
else:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[str | os.PathLike] = None,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
@@ -267,7 +315,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
@@ -297,7 +345,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
)
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a guider configuration object to a directory so that it can be reloaded using the
[`~BaseGuidance.from_pretrained`] class method.
@@ -309,7 +357,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`dict[str, Any]`, *optional*):
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
@@ -325,7 +373,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
@@ -60,10 +58,10 @@ class PerturbedAttentionGuidance(BaseGuidance):
The fraction of the total number of denoising steps after which perturbed attention guidance starts.
perturbed_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
perturbed_guidance_layers (`int` or `list[int]`, *optional*):
perturbed_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
If not provided, `perturbed_guidance_config` must be provided.
perturbed_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
@@ -94,8 +92,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
perturbed_guidance_scale: float = 2.8,
perturbed_guidance_start: float = 0.01,
perturbed_guidance_stop: float = 0.2,
perturbed_guidance_layers: Optional[int | list[int]] = None,
perturbed_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
@@ -171,7 +169,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
registry.remove_hook(hook_name, recurse=True)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -189,6 +187,26 @@ class PerturbedAttentionGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
def forward(
self,
+26 -8
View File
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
@@ -66,11 +64,11 @@ class SkipLayerGuidance(BaseGuidance):
The fraction of the total number of denoising steps after which skip layer guidance starts.
skip_layer_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which skip layer guidance stops.
skip_layer_guidance_layers (`int` or `list[int]`, *optional*):
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
skip_layer_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
@@ -96,8 +94,8 @@ class SkipLayerGuidance(BaseGuidance):
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_start: float = 0.01,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_layers: Optional[int | list[int]] = None,
skip_layer_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
@@ -167,7 +165,7 @@ class SkipLayerGuidance(BaseGuidance):
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -185,6 +183,26 @@ class SkipLayerGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -56,11 +54,11 @@ class SmoothedEnergyGuidance(BaseGuidance):
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
seg_guidance_stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
seg_guidance_layers (`int` or `list[int]`, *optional*):
seg_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
Diffusion 3.5 Medium.
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `list[SmoothedEnergyGuidanceConfig]`, *optional*):
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
@@ -88,8 +86,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
seg_blur_threshold_inf: float = 9999.0,
seg_guidance_start: float = 0.0,
seg_guidance_stop: float = 1.0,
seg_guidance_layers: Optional[int | list[int]] = None,
seg_guidance_config: SmoothedEnergyGuidanceConfig | list[SmoothedEnergyGuidanceConfig] = None,
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
@@ -156,7 +154,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -174,6 +172,26 @@ class SmoothedEnergyGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -68,7 +66,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
@@ -76,6 +74,16 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
+22 -2
View File
@@ -14,7 +14,7 @@
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Type
from typing import Any, Callable, Dict, Type
@dataclass
@@ -28,7 +28,7 @@ class TransformerBlockMetadata:
return_encoder_hidden_states_index: int = None
_cls: Type = None
_cached_parameter_indices: dict[str, int] = None
_cached_parameter_indices: Dict[str, int] = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
@@ -111,6 +111,7 @@ def _register_attention_processors_metadata():
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor
# AttnProcessor2_0
AttentionProcessorRegistry.register(
@@ -158,6 +159,14 @@ def _register_attention_processors_metadata():
),
)
# ZSingleStreamAttnProcessor
AttentionProcessorRegistry.register(
model_class=ZSingleStreamAttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -179,6 +188,7 @@ def _register_transformer_blocks_metadata():
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
from ..models.transformers.transformer_z_image import ZImageTransformerBlock
# BasicTransformerBlock
TransformerBlockRegistry.register(
@@ -312,6 +322,15 @@ def _register_transformer_blocks_metadata():
),
)
# ZImage
TransformerBlockRegistry.register(
model_class=ZImageTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
@@ -338,4 +357,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on
+11 -9
View File
@@ -14,7 +14,7 @@
import inspect
from dataclasses import dataclass
from typing import Type
from typing import Dict, List, Type, Union
import torch
@@ -42,7 +42,7 @@ _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
cached_parameter_indices: dict[str, int] = None
cached_parameter_indices: Dict[str, int] = None
_cls: Type = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
@@ -78,7 +78,7 @@ class ModuleForwardMetadata:
def apply_context_parallel(
module: torch.nn.Module,
parallel_config: ContextParallelConfig,
plan: dict[str, ContextParallelModelPlan],
plan: Dict[str, ContextParallelModelPlan],
) -> None:
"""Apply context parallel on a model."""
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
@@ -107,7 +107,7 @@ def apply_context_parallel(
registry.register_hook(hook, hook_name)
def remove_context_parallel(module: torch.nn.Module, plan: dict[str, ContextParallelModelPlan]) -> None:
def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
@@ -203,10 +203,12 @@ class ContextParallelSplitHook(ModelHook):
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
raise ValueError(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
logger.warning_once(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
return x
else:
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
class ContextParallelGatherHook(ModelHook):
@@ -272,13 +274,13 @@ class EquipartitionSharder:
return tensor
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name.count("*") > 1:
raise ValueError("Wildcard '*' can only be used once in the name")
return _find_submodule_by_name(model, name)
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name == "":
return model
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
+22 -22
View File
@@ -14,7 +14,7 @@
import re
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable, List, Optional, Tuple
import torch
@@ -60,7 +60,7 @@ class FasterCacheConfig:
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
states again.
spatial_attention_timestep_skip_range (`tuple[float, float]`, defaults to `(-1, 681)`):
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
The timestep range within which the spatial attention computation can be skipped without a significant loss
in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
@@ -68,17 +68,17 @@ class FasterCacheConfig:
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
process.
temporal_attention_timestep_skip_range (`tuple[float, float]`, *optional*, defaults to `None`):
temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
The timestep range within which the temporal attention computation can be skipped without a significant
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
timestep 0).
low_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(99, 901)`):
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
The timestep range within which the low frequency weight scaling update is applied. The first value in the
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
function for the update is called only within this range.
high_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(-1, 301)`):
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
The timestep range within which the high frequency weight scaling update is applied. The first value in the
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
function for the update is called only within this range.
@@ -92,15 +92,15 @@ class FasterCacheConfig:
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be reused) before
computing the new unconditional branch states again.
unconditional_batch_timestep_skip_range (`tuple[float, float]`, defaults to `(-1, 641)`):
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
The timestep range within which the unconditional branch computation can be skipped without a significant
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound.
spatial_attention_block_identifiers (`tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
partial layer names, or regex patterns. Matching will always be done using a regex match.
temporal_attention_block_identifiers (`tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
partial layer names, or regex patterns. Matching will always be done using a regex match.
@@ -123,7 +123,7 @@ class FasterCacheConfig:
is_guidance_distilled (`bool`, defaults to `False`):
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
_unconditional_conditional_input_kwargs_identifiers (`list[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
@@ -135,12 +135,12 @@ class FasterCacheConfig:
spatial_attention_block_skip_range: int = 2
temporal_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
temporal_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
low_frequency_weight_update_timestep_range: tuple[int, int] = (99, 901)
high_frequency_weight_update_timestep_range: tuple[int, int] = (-1, 301)
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
# 1 and 2 as mentioned in Equation 11 of the paper
alpha_low_frequency: float = 1.1
@@ -148,10 +148,10 @@ class FasterCacheConfig:
# n as described in CFG-Cache explanation in the paper - dependent on the model
unconditional_batch_skip_range: int = 5
unconditional_batch_timestep_skip_range: tuple[int, int] = (-1, 641)
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
spatial_attention_block_identifiers: tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
attention_weight_callback: Callable[[torch.nn.Module], float] = None
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
@@ -162,7 +162,7 @@ class FasterCacheConfig:
current_timestep_callback: Callable[[], int] = None
_unconditional_conditional_input_kwargs_identifiers: list[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
def __repr__(self) -> str:
return (
@@ -209,7 +209,7 @@ class FasterCacheBlockState:
def __init__(self) -> None:
self.iteration: int = 0
self.batch_size: int = None
self.cache: tuple[torch.Tensor, torch.Tensor] = None
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
def reset(self):
self.iteration = 0
@@ -223,10 +223,10 @@ class FasterCacheDenoiserHook(ModelHook):
def __init__(
self,
unconditional_batch_skip_range: int,
unconditional_batch_timestep_skip_range: tuple[int, int],
unconditional_batch_timestep_skip_range: Tuple[int, int],
tensor_format: str,
is_guidance_distilled: bool,
uncond_cond_input_kwargs_identifiers: list[str],
uncond_cond_input_kwargs_identifiers: List[str],
current_timestep_callback: Callable[[], int],
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
@@ -252,7 +252,7 @@ class FasterCacheDenoiserHook(ModelHook):
return module
@staticmethod
def _get_cond_input(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
# followed by conditional inputs.
_, cond = input.chunk(2, dim=0)
@@ -371,7 +371,7 @@ class FasterCacheBlockHook(ModelHook):
def __init__(
self,
block_skip_range: int,
timestep_skip_range: tuple[int, int],
timestep_skip_range: Tuple[int, int],
is_guidance_distilled: bool,
weight_callback: Callable[[torch.nn.Module], float],
current_timestep_callback: Callable[[], int],
+3 -2
View File
@@ -13,6 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple, Union
import torch
@@ -52,9 +53,9 @@ class FBCSharedBlockState(BaseState):
def __init__(self) -> None:
super().__init__()
self.head_block_output: torch.Tensor | tuple[torch.Tensor, ...] = None
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: torch.Tensor | tuple[torch.Tensor, ...] = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
def reset(self):
+13 -13
View File
@@ -17,7 +17,7 @@ import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Set
from typing import Dict, List, Optional, Set, Tuple, Union
import safetensors.torch
import torch
@@ -58,21 +58,21 @@ class GroupOffloadingConfig:
low_cpu_mem_usage: bool
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[torch.cuda.Stream | torch.Stream] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
class ModuleGroup:
def __init__(
self,
modules: list[torch.nn.Module],
modules: List[torch.nn.Module],
offload_device: torch.device,
onload_device: torch.device,
offload_leader: torch.nn.Module,
onload_leader: Optional[torch.nn.Module] = None,
parameters: Optional[list[torch.nn.Parameter]] = None,
buffers: Optional[list[torch.Tensor]] = None,
parameters: Optional[List[torch.nn.Parameter]] = None,
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: torch.cuda.Stream | torch.Stream | None = None,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
@@ -340,7 +340,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(self):
self.execution_order: list[tuple[str, torch.nn.Module]] = []
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
self._layer_execution_tracker_module_names = set()
def initialize_hook(self, module):
@@ -444,9 +444,9 @@ class LayerExecutionTrackerHook(ModelHook):
def apply_group_offloading(
module: torch.nn.Module,
onload_device: str | torch.device,
offload_device: str | torch.device = torch.device("cpu"),
offload_type: str | GroupOffloadingType = "block_level",
onload_device: Union[str, torch.device],
offload_device: Union[str, torch.device] = torch.device("cpu"),
offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
@@ -787,7 +787,7 @@ def _apply_lazy_group_offloading_hook(
def _gather_parameters_with_no_group_offloading_parent(
module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> list[torch.nn.Parameter]:
) -> List[torch.nn.Parameter]:
parameters = []
for name, parameter in module.named_parameters():
has_parent_with_group_offloading = False
@@ -805,7 +805,7 @@ def _gather_parameters_with_no_group_offloading_parent(
def _gather_buffers_with_no_group_offloading_parent(
module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> list[torch.Tensor]:
) -> List[torch.Tensor]:
buffers = []
for name, buffer in module.named_buffers():
has_parent_with_group_offloading = False
@@ -821,7 +821,7 @@ def _gather_buffers_with_no_group_offloading_parent(
return buffers
def _find_parent_module_in_module_dict(name: str, module_dict: dict[str, torch.nn.Module]) -> str:
def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
+6 -6
View File
@@ -13,7 +13,7 @@
# limitations under the License.
import functools
from typing import Any, Optional
from typing import Any, Dict, Optional, Tuple
import torch
@@ -86,19 +86,19 @@ class ModelHook:
"""
return module
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> tuple[tuple[Any], dict[str, Any]]:
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
r"""
Hook that is executed just before the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass will be executed just after this event.
args (`tuple[Any]`):
args (`Tuple[Any]`):
The positional arguments passed to the module.
kwargs (`dict[Str, Any]`):
kwargs (`Dict[Str, Any]`):
The keyword arguments passed to the module.
Returns:
`tuple[tuple[Any], dict[Str, Any]]`:
`Tuple[Tuple[Any], Dict[Str, Any]]`:
A tuple with the treated `args` and `kwargs`.
"""
return args, kwargs
@@ -168,7 +168,7 @@ class HookRegistry:
def __init__(self, module_ref: torch.nn.Module) -> None:
super().__init__()
self.hooks: dict[str, ModelHook] = {}
self.hooks: Dict[str, ModelHook] = {}
self._module_ref = module_ref
self._hook_order = []
+3 -3
View File
@@ -14,7 +14,7 @@
import math
from dataclasses import asdict, dataclass
from typing import Callable, Optional
from typing import Callable, List, Optional
import torch
@@ -43,7 +43,7 @@ class LayerSkipConfig:
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`list[int]`):
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
@@ -63,7 +63,7 @@ class LayerSkipConfig:
skipped layers are fully retained, which is equivalent to not skipping any layers.
"""
indices: list[int]
indices: List[int]
fqn: str = "auto"
skip_attention: bool = True
skip_attention_scores: bool = False
+7 -7
View File
@@ -13,7 +13,7 @@
# limitations under the License.
import re
from typing import Optional, Type
from typing import Optional, Tuple, Type, Union
import torch
@@ -102,8 +102,8 @@ def apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
compute_dtype: torch.dtype,
skip_modules_pattern: str | tuple[str, ...] = "auto",
skip_modules_classes: Optional[tuple[Type[torch.nn.Module], ...]] = None,
skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
) -> None:
r"""
@@ -137,12 +137,12 @@ def apply_layerwise_casting(
The dtype to cast the module to before/after the forward pass for storage.
compute_dtype (`torch.dtype`):
The dtype to cast the module to during the forward pass for computation.
skip_modules_pattern (`tuple[str, ...]`, defaults to `"auto"`):
skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`):
A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
instead of its internal submodules.
skip_modules_classes (`tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
A list of module classes to skip during the layerwise casting process.
non_blocking (`bool`, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
@@ -169,8 +169,8 @@ def _apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
compute_dtype: torch.dtype,
skip_modules_pattern: Optional[tuple[str, ...]] = None,
skip_modules_classes: Optional[tuple[Type[torch.nn.Module], ...]] = None,
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
_prefix: str = "",
) -> None:
@@ -14,7 +14,7 @@
import re
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Tuple, Union
import torch
@@ -54,20 +54,20 @@ class PyramidAttentionBroadcastConfig:
The number of times a specific cross-attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
old attention states will be reused) before computing the new attention states again.
spatial_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the spatial attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
temporal_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the temporal attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
cross_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
spatial_attention_block_identifiers (`tuple[str, ...]`):
spatial_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
temporal_attention_block_identifiers (`tuple[str, ...]`):
temporal_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
cross_attention_block_identifiers (`tuple[str, ...]`):
cross_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
"""
@@ -75,13 +75,13 @@ class PyramidAttentionBroadcastConfig:
temporal_attention_block_skip_range: Optional[int] = None
cross_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: tuple[int, int] = (100, 800)
temporal_attention_timestep_skip_range: tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: tuple[int, int] = (100, 800)
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
spatial_attention_block_identifiers: tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
current_timestep_callback: Callable[[], int] = None
@@ -141,7 +141,7 @@ class PyramidAttentionBroadcastHook(ModelHook):
_is_stateful = True
def __init__(
self, timestep_skip_range: tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
) -> None:
super().__init__()
@@ -288,8 +288,8 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
def _apply_pyramid_attention_broadcast_hook(
module: Attention | MochiAttention,
timestep_skip_range: tuple[int, int],
module: Union[Attention, MochiAttention],
timestep_skip_range: Tuple[int, int],
block_skip_range: int,
current_timestep_callback: Callable[[], int],
):
@@ -299,7 +299,7 @@ def _apply_pyramid_attention_broadcast_hook(
Args:
module (`torch.nn.Module`):
The module to apply Pyramid Attention Broadcast to.
timestep_skip_range (`tuple[int, int]`):
timestep_skip_range (`Tuple[int, int]`):
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
skipped if the current timestep is within the specified range.
block_skip_range (`int`):
@@ -14,7 +14,7 @@
import math
from dataclasses import asdict, dataclass
from typing import Optional
from typing import List, Optional
import torch
import torch.nn.functional as F
@@ -35,21 +35,21 @@ class SmoothedEnergyGuidanceConfig:
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`list[int]`):
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
provide the correct fqn.
_query_proj_identifiers (`list[str]`, defaults to `None`):
_query_proj_identifiers (`List[str]`, defaults to `None`):
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
`None`, `to_q` is used by default.
"""
indices: list[int]
indices: List[int]
fqn: str = "auto"
_query_proj_identifiers: list[str] = None
_query_proj_identifiers: List[str] = None
def to_dict(self):
return asdict(self)
+2 -2
View File
@@ -21,8 +21,8 @@ def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
module_list_with_transformer_blocks = []
for name, submodule in module.named_modules():
name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
is_ModuleList = isinstance(submodule, torch.nn.ModuleList)
if name_endswith_identifier and is_ModuleList:
is_modulelist = isinstance(submodule, torch.nn.ModuleList)
if name_endswith_identifier and is_modulelist:
module_list_with_transformer_blocks.append((name, submodule))
return module_list_with_transformer_blocks
+82 -54
View File
@@ -14,7 +14,7 @@
import math
import warnings
from typing import Optional
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL.Image
@@ -26,9 +26,14 @@ from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
PipelineImageInput = (
PIL.Image.Image | np.ndarray | torch.Tensor | list[PIL.Image.Image] | list[np.ndarray] | list[torch.Tensor]
)
PipelineImageInput = Union[
PIL.Image.Image,
np.ndarray,
torch.Tensor,
List[PIL.Image.Image],
List[np.ndarray],
List[torch.Tensor],
]
PipelineDepthInput = PipelineImageInput
@@ -63,7 +68,7 @@ def is_valid_image_imagelist(images):
- A list of valid images.
Args:
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, list]`):
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
images.
@@ -126,7 +131,7 @@ class VaeImageProcessor(ConfigMixin):
)
@staticmethod
def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]:
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
r"""
Convert a numpy image or a batch of images to a PIL image.
@@ -135,7 +140,7 @@ class VaeImageProcessor(ConfigMixin):
The image array to convert to PIL format.
Returns:
`list[PIL.Image.Image]`:
`List[PIL.Image.Image]`:
A list of PIL images.
"""
if images.ndim == 3:
@@ -150,12 +155,12 @@ class VaeImageProcessor(ConfigMixin):
return pil_images
@staticmethod
def pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
r"""
Convert a PIL image or a list of PIL images to NumPy arrays.
Args:
images (`PIL.Image.Image` or `list[PIL.Image.Image]`):
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
The PIL image or list of images to convert to NumPy format.
Returns:
@@ -205,7 +210,7 @@ class VaeImageProcessor(ConfigMixin):
return images
@staticmethod
def normalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
Normalize an image array to [-1,1].
@@ -220,7 +225,7 @@ class VaeImageProcessor(ConfigMixin):
return 2.0 * images - 1.0
@staticmethod
def denormalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
Denormalize an image array to [0,1].
@@ -404,7 +409,7 @@ class VaeImageProcessor(ConfigMixin):
src_w = width if ratio < src_ratio else image.width * height // image.height
src_h = height if ratio >= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
@@ -455,18 +460,18 @@ class VaeImageProcessor(ConfigMixin):
src_w = width if ratio > src_ratio else image.width * height // image.height
src_h = height if ratio <= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
return res
def resize(
self,
image: PIL.Image.Image | np.ndarray | torch.Tensor,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: int,
width: int,
resize_mode: str = "default", # "default", "fill", "crop"
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Resize image.
@@ -539,7 +544,7 @@ class VaeImageProcessor(ConfigMixin):
return image
def _denormalize_conditionally(
self, images: torch.Tensor, do_denormalize: Optional[list[bool]] = None
self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
) -> torch.Tensor:
r"""
Denormalize a batch of images based on a condition list.
@@ -547,7 +552,7 @@ class VaeImageProcessor(ConfigMixin):
Args:
images (`torch.Tensor`):
The input image tensor.
do_denormalize (`Optional[list[bool]`, *optional*, defaults to `None`):
do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
value of `do_normalize` in the `VaeImageProcessor` config.
"""
@@ -560,10 +565,10 @@ class VaeImageProcessor(ConfigMixin):
def get_default_height_width(
self,
image: PIL.Image.Image | np.ndarray | torch.Tensor,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
) -> tuple[int, int]:
) -> Tuple[int, int]:
r"""
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
@@ -578,7 +583,7 @@ class VaeImageProcessor(ConfigMixin):
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
Returns:
`tuple[int, int]`:
`Tuple[int, int]`:
A tuple containing the height and width, both resized to the nearest integer multiple of
`vae_scale_factor`.
"""
@@ -611,7 +616,7 @@ class VaeImageProcessor(ConfigMixin):
height: Optional[int] = None,
width: Optional[int] = None,
resize_mode: str = "default", # "default", "fill", "crop"
crops_coords: Optional[tuple[int, int, int, int]] = None,
crops_coords: Optional[Tuple[int, int, int, int]] = None,
) -> torch.Tensor:
"""
Preprocess the image input.
@@ -633,7 +638,7 @@ class VaeImageProcessor(ConfigMixin):
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
supported for PIL image input.
crops_coords (`list[tuple[int, int, int, int]]`, *optional*, defaults to `None`):
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
The crop coordinates for each image in the batch. If `None`, will not crop the image.
Returns:
@@ -740,8 +745,8 @@ class VaeImageProcessor(ConfigMixin):
self,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: Optional[list[bool]] = None,
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Postprocess the image output from tensor to `output_type`.
@@ -750,7 +755,7 @@ class VaeImageProcessor(ConfigMixin):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
do_denormalize (`list[bool]`, *optional*, defaults to `None`):
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
`VaeImageProcessor` config.
@@ -791,7 +796,7 @@ class VaeImageProcessor(ConfigMixin):
mask: PIL.Image.Image,
init_image: PIL.Image.Image,
image: PIL.Image.Image,
crop_coords: Optional[tuple[int, int, int, int]] = None,
crop_coords: Optional[Tuple[int, int, int, int]] = None,
) -> PIL.Image.Image:
r"""
Applies an overlay of the mask and the inpainted image on the original image.
@@ -803,7 +808,7 @@ class VaeImageProcessor(ConfigMixin):
The original image to which the overlay is applied.
image (`PIL.Image.Image`):
The image to overlay onto the original.
crop_coords (`tuple[int, int, int, int]`, *optional*):
crop_coords (`Tuple[int, int, int, int]`, *optional*):
Coordinates to crop the image. If provided, the image will be cropped accordingly.
Returns:
@@ -886,7 +891,7 @@ class InpaintProcessor(ConfigMixin):
height: int = None,
width: int = None,
padding_mask_crop: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Preprocess the image and mask.
"""
@@ -941,8 +946,8 @@ class InpaintProcessor(ConfigMixin):
output_type: str = "pil",
original_image: Optional[PIL.Image.Image] = None,
original_mask: Optional[PIL.Image.Image] = None,
crops_coords: Optional[tuple[int, int, int, int]] = None,
) -> tuple[PIL.Image.Image, PIL.Image.Image]:
crops_coords: Optional[Tuple[int, int, int, int]] = None,
) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
"""
Postprocess the image, optionally apply mask overlay
"""
@@ -993,7 +998,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
super().__init__()
@staticmethod
def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]:
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
r"""
Convert a NumPy image or a batch of images to a list of PIL images.
@@ -1002,7 +1007,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
The input NumPy array of images, which can be a single image or a batch.
Returns:
`list[PIL.Image.Image]`:
`List[PIL.Image.Image]`:
A list of PIL images converted from the input NumPy array.
"""
if images.ndim == 3:
@@ -1017,12 +1022,12 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
return pil_images
@staticmethod
def depth_pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:
def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
r"""
Convert a PIL image or a list of PIL images to NumPy arrays.
Args:
images (`Union[list[PIL.Image.Image], PIL.Image.Image]`):
images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
The input image or list of images to be converted.
Returns:
@@ -1037,21 +1042,44 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
return images
@staticmethod
def rgblike_to_depthmap(image: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
Convert an RGB-like depth image to a depth map.
Args:
image (`Union[np.ndarray, torch.Tensor]`):
The RGB-like depth image to convert.
Returns:
`Union[np.ndarray, torch.Tensor]`:
The corresponding depth map.
"""
return image[:, :, 1] * 2**8 + image[:, :, 2]
# 1. Cast the tensor to a larger integer type (e.g., int32)
# to safely perform the multiplication by 256.
# 2. Perform the 16-bit combination: High-byte * 256 + Low-byte.
# 3. Cast the final result to the desired depth map type (uint16) if needed
# before returning, though leaving it as int32/int64 is often safer
# for return value from a library function.
def numpy_to_depth(self, images: np.ndarray) -> list[PIL.Image.Image]:
if isinstance(image, torch.Tensor):
# Cast to a safe dtype (e.g., int32 or int64) for the calculation
original_dtype = image.dtype
image_safe = image.to(torch.int32)
# Calculate the depth map
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
# You may want to cast the final result to uint16, but casting to a
# larger int type (like int32) is sufficient to fix the overflow.
# depth_map = depth_map.to(torch.uint16) # Uncomment if uint16 is strictly required
return depth_map.to(original_dtype)
elif isinstance(image, np.ndarray):
# NumPy equivalent: Cast to a safe dtype (e.g., np.int32)
original_dtype = image.dtype
image_safe = image.astype(np.int32)
# Calculate the depth map
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
# depth_map = depth_map.astype(np.uint16) # Uncomment if uint16 is strictly required
return depth_map.astype(original_dtype)
else:
raise TypeError("Input image must be a torch.Tensor or np.ndarray")
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
r"""
Convert a NumPy depth image or a batch of images to a list of PIL images.
@@ -1060,7 +1088,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
The input NumPy array of depth images, which can be a single image or a batch.
Returns:
`list[PIL.Image.Image]`:
`List[PIL.Image.Image]`:
A list of PIL images converted from the input NumPy depth images.
"""
if images.ndim == 3:
@@ -1083,8 +1111,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
self,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: Optional[list[bool]] = None,
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Postprocess the image output from tensor to `output_type`.
@@ -1093,7 +1121,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
do_denormalize (`list[bool]`, *optional*, defaults to `None`):
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
`VaeImageProcessor` config.
@@ -1131,8 +1159,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
def preprocess(
self,
rgb: torch.Tensor | PIL.Image.Image | np.ndarray,
depth: torch.Tensor | PIL.Image.Image | np.ndarray,
rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
height: Optional[int] = None,
width: Optional[int] = None,
target_res: Optional[int] = None,
@@ -1153,7 +1181,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
Target resolution for resizing the images. If specified, overrides height and width.
Returns:
`tuple[torch.Tensor, torch.Tensor]`:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing the processed RGB and depth images as PyTorch tensors.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
@@ -1391,7 +1419,7 @@ class PixArtImageProcessor(VaeImageProcessor):
)
@staticmethod
def classify_height_width_bin(height: int, width: int, ratios: dict) -> tuple[int, int]:
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
r"""
Returns the binned height and width based on the aspect ratio.
@@ -1401,7 +1429,7 @@ class PixArtImageProcessor(VaeImageProcessor):
ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
Returns:
`tuple[int, int]`: The closest binned height and width.
`Tuple[int, int]`: The closest binned height and width.
"""
ar = float(height / width)
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+2
View File
@@ -81,6 +81,7 @@ if is_torch_available():
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
"Flux2LoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
@@ -113,6 +114,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
Flux2LoraLoaderMixin,
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
+31 -31
View File
@@ -13,7 +13,7 @@
# limitations under the License.
from pathlib import Path
from typing import Optional
from typing import Dict, List, Optional, Union
import torch
import torch.nn.functional as F
@@ -57,15 +57,15 @@ class IPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
subfolder: str | list[str],
weight_name: str | list[str],
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
image_encoder_folder: Optional[str] = "image_encoder",
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
@@ -74,10 +74,10 @@ class IPAdapterMixin:
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `list[str]`):
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `list[str]`):
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
@@ -94,7 +94,7 @@ class IPAdapterMixin:
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -358,14 +358,14 @@ class ModularIPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
subfolder: str | list[str],
weight_name: str | list[str],
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
@@ -374,10 +374,10 @@ class ModularIPAdapterMixin:
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `list[str]`):
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `list[str]`):
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
@@ -387,7 +387,7 @@ class ModularIPAdapterMixin:
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -608,9 +608,9 @@ class FluxIPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
weight_name: str | list[str],
subfolder: Optional[str | list[str]] = "",
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
weight_name: Union[str, List[str]],
subfolder: Optional[Union[str, List[str]]] = "",
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
image_encoder_subfolder: Optional[str] = "",
image_encoder_dtype: torch.dtype = torch.float16,
@@ -618,7 +618,7 @@ class FluxIPAdapterMixin:
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
@@ -627,10 +627,10 @@ class FluxIPAdapterMixin:
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `list[str]`):
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `list[str]`):
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`weight_name`.
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
@@ -647,7 +647,7 @@ class FluxIPAdapterMixin:
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -797,13 +797,13 @@ class FluxIPAdapterMixin:
# load ip-adapter into transformer
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: float | list[float] | list[list[float]]):
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a list.
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `list[float]`
length match the number of blocks, it is repeated for each IP adapter. `list[list[float]]` must match the
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
number of IP adapters and each must match the number of blocks.
Example:
@@ -823,18 +823,18 @@ class FluxIPAdapterMixin:
```
"""
scale_type = int | float
scale_type = Union[int, float]
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
num_layers = self.transformer.config.num_layers
# Single value for all layers of all IP-Adapters
if isinstance(scale, scale_type):
scale = [scale for _ in range(num_ip_adapters)]
# list of per-layer scales for a single IP-Adapter
elif _is_valid_type(scale, list[scale_type]) and num_ip_adapters == 1:
# List of per-layer scales for a single IP-Adapter
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
scale = [scale]
# Invalid scale type
elif not _is_valid_type(scale, list[scale_type | list[scale_type]]):
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
if len(scale) != num_ip_adapters:
@@ -918,7 +918,7 @@ class SD3IPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
weight_name: str = "ip-adapter.safetensors",
subfolder: Optional[str] = None,
image_encoder_folder: Optional[str] = "image_encoder",
@@ -953,7 +953,7 @@ class SD3IPAdapterMixin:
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
+27 -27
View File
@@ -17,7 +17,7 @@ import inspect
import json
import os
from pathlib import Path
from typing import Callable, Dict, Optional
from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
@@ -77,7 +77,7 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adap
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`list[str]` or `str`):
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
"""
merge_kwargs = {"safe_merge": safe_fusing}
@@ -116,20 +116,20 @@ def unfuse_text_encoder_lora(text_encoder):
def set_adapters_for_text_encoder(
adapter_names: list[str] | str,
adapter_names: Union[List[str], str],
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: Optional[float | list[float] | list[None]] = None,
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
):
"""
Sets the adapter layers for the text encoder.
Args:
adapter_names (`list[str]` or `str`):
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
text_encoder_weights (`list[float]`, *optional*):
text_encoder_weights (`List[float]`, *optional*):
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
"""
if text_encoder is None:
@@ -535,10 +535,10 @@ class LoraBaseMixin:
def fuse_lora(
self,
components: list[str] = [],
components: List[str] = [],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -547,12 +547,12 @@ class LoraBaseMixin:
> [!WARNING] > This is an experimental API.
Args:
components: (`list[str]`): list of LoRA-injectable components to fuse the LoRAs into.
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`list[str]`, *optional*):
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
@@ -619,7 +619,7 @@ class LoraBaseMixin:
self._merged_adapters = self._merged_adapters | merged_adapter_names
def unfuse_lora(self, components: list[str] = [], **kwargs):
def unfuse_lora(self, components: List[str] = [], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -627,7 +627,7 @@ class LoraBaseMixin:
> [!WARNING] > This is an experimental API.
Args:
components (`list[str]`): list of LoRA-injectable components to unfuse LoRA from.
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
@@ -674,16 +674,16 @@ class LoraBaseMixin:
def set_adapters(
self,
adapter_names: list[str] | str,
adapter_weights: Optional[float | Dict | list[float] | list[Dict]] = None,
adapter_names: Union[List[str], str],
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
"""
Set the currently active adapters for use in the pipeline.
Args:
adapter_names (`list[str]` or `str`):
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
adapter_weights (`Union[list[float], float]`, *optional*):
adapter_weights (`Union[List[float], float]`, *optional*):
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
adapters.
@@ -835,12 +835,12 @@ class LoraBaseMixin:
elif issubclass(model.__class__, PreTrainedModel):
enable_lora_for_text_encoder(model)
def delete_adapters(self, adapter_names: list[str] | str):
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Delete an adapter's LoRA layers from the pipeline.
Args:
adapter_names (`Union[list[str], str]`):
adapter_names (`Union[List[str], str]`):
The names of the adapters to delete.
Example:
@@ -873,7 +873,7 @@ class LoraBaseMixin:
for adapter_name in adapter_names:
delete_adapter_layers(model, adapter_name)
def get_active_adapters(self) -> list[str]:
def get_active_adapters(self) -> List[str]:
"""
Gets the list of the current active adapters.
@@ -906,7 +906,7 @@ class LoraBaseMixin:
return active_adapters
def get_list_adapters(self) -> dict[str, list[str]]:
def get_list_adapters(self) -> Dict[str, List[str]]:
"""
Gets the current list of all available adapters in the pipeline.
"""
@@ -928,7 +928,7 @@ class LoraBaseMixin:
return set_adapters
def set_lora_device(self, adapter_names: list[str], device: torch.device | str | int) -> None:
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
"""
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
@@ -955,8 +955,8 @@ class LoraBaseMixin:
```
Args:
adapter_names (`list[str]`):
list of adapters to send device to.
adapter_names (`List[str]`):
List of adapters to send device to.
device (`Union[torch.device, str, int]`):
Device to send the adapters to. Can be either a torch device, a str or an integer.
"""
@@ -1007,7 +1007,7 @@ class LoraBaseMixin:
@staticmethod
def write_lora_layers(
state_dict: dict[str, torch.Tensor],
state_dict: Dict[str, torch.Tensor],
save_directory: str,
is_main_process: bool,
weight_name: str,
@@ -1059,9 +1059,9 @@ class LoraBaseMixin:
@classmethod
def _save_lora_weights(
cls,
save_directory: str | os.PathLike,
lora_layers: dict[str, dict[str, torch.nn.Module | torch.Tensor]],
lora_metadata: dict[str, Optional[dict]],
save_directory: Union[str, os.PathLike],
lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
lora_metadata: Dict[str, Optional[dict]],
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
+92 -1
View File
@@ -13,6 +13,7 @@
# limitations under the License.
import re
from typing import List
import torch
@@ -1020,7 +1021,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
return new_state_dict
def _custom_replace(key: str, substrings: list[str]) -> str:
def _custom_replace(key: str, substrings: List[str]) -> str:
# Replaces the "."s with "_"s upto the `substrings`.
# Example:
# lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
@@ -2212,6 +2213,10 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
state_dict = {convert_key(k): v for k, v in state_dict.items()}
has_default = any("default." in k for k in state_dict)
if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
@@ -2260,3 +2265,89 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
converted_state_dict = {}
prefix = "diffusion_model."
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
num_double_layers = 8
num_single_layers = 48
lora_keys = ("lora_A", "lora_B")
attn_types = ("img_attn", "txt_attn")
for sl in range(num_single_layers):
single_block_prefix = f"single_blocks.{sl}"
attn_prefix = f"single_transformer_blocks.{sl}.attn"
for lora_key in lora_keys:
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear1.{lora_key}.weight"
)
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear2.{lora_key}.weight"
)
for dl in range(num_double_layers):
transformer_block_prefix = f"transformer_blocks.{dl}"
for lora_key in lora_keys:
for attn_type in attn_types:
attn_prefix = f"{transformer_block_prefix}.attn"
qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
fused_qkv_weight = original_state_dict.pop(qkv_key)
if lora_key == "lora_A":
diff_attn_proj_keys = (
["to_q", "to_k", "to_v"]
if attn_type == "img_attn"
else ["add_q_proj", "add_k_proj", "add_v_proj"]
)
for proj_key in diff_attn_proj_keys:
converted_state_dict[f"{attn_prefix}.{proj_key}.{lora_key}.weight"] = torch.cat(
[fused_qkv_weight]
)
else:
sample_q, sample_k, sample_v = torch.chunk(fused_qkv_weight, 3, dim=0)
if attn_type == "img_attn":
converted_state_dict[f"{attn_prefix}.to_q.{lora_key}.weight"] = torch.cat([sample_q])
converted_state_dict[f"{attn_prefix}.to_k.{lora_key}.weight"] = torch.cat([sample_k])
converted_state_dict[f"{attn_prefix}.to_v.{lora_key}.weight"] = torch.cat([sample_v])
else:
converted_state_dict[f"{attn_prefix}.add_q_proj.{lora_key}.weight"] = torch.cat([sample_q])
converted_state_dict[f"{attn_prefix}.add_k_proj.{lora_key}.weight"] = torch.cat([sample_k])
converted_state_dict[f"{attn_prefix}.add_v_proj.{lora_key}.weight"] = torch.cat([sample_v])
proj_mappings = [
("img_attn.proj", "attn.to_out.0"),
("txt_attn.proj", "attn.to_add_out"),
]
for org_proj, diff_proj in proj_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
mlp_mappings = [
("img_mlp.0", "ff.linear_in"),
("img_mlp.2", "ff.linear_out"),
("txt_mlp.0", "ff_context.linear_in"),
("txt_mlp.2", "ff_context.linear_out"),
]
for org_mlp, diff_mlp in mlp_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
File diff suppressed because it is too large Load Diff
+13 -12
View File
@@ -17,7 +17,7 @@ import json
import os
from functools import partial
from pathlib import Path
from typing import Dict, Literal, Optional
from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
@@ -62,6 +62,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
}
@@ -113,7 +114,7 @@ class PeftAdapterMixin:
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -127,7 +128,7 @@ class PeftAdapterMixin:
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -447,16 +448,16 @@ class PeftAdapterMixin:
def set_adapters(
self,
adapter_names: list[str] | str,
weights: Optional[float | Dict | list[float] | list[Dict] | list[None]] = None,
adapter_names: Union[List[str], str],
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
):
"""
Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
Args:
adapter_names (`list[str]` or `str`):
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
adapter_weights (`Union[list[float], float]`, *optional*):
adapter_weights (`Union[List[float], float]`, *optional*):
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
adapters.
@@ -539,7 +540,7 @@ class PeftAdapterMixin:
inject_adapter_in_model(adapter_config, self, adapter_name)
self.set_adapter(adapter_name)
def set_adapter(self, adapter_name: str | list[str]) -> None:
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
"""
Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
@@ -547,7 +548,7 @@ class PeftAdapterMixin:
[documentation](https://huggingface.co/docs/peft).
Args:
adapter_name (Union[str, list[str]])):
adapter_name (Union[str, List[str]])):
The list of adapters to set or the adapter name in the case of a single adapter.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
@@ -633,7 +634,7 @@ class PeftAdapterMixin:
# support for older PEFT versions
module.disable_adapters = False
def active_adapters(self) -> list[str]:
def active_adapters(self) -> List[str]:
"""
Gets the current list of active adapters of the model.
@@ -756,12 +757,12 @@ class PeftAdapterMixin:
raise ValueError("PEFT backend is required for this method.")
set_adapter_layers(self, enabled=True)
def delete_adapters(self, adapter_names: list[str] | str):
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Delete an adapter's LoRA layers from the underlying model.
Args:
adapter_names (`Union[list[str], str]`):
adapter_names (`Union[List[str], str]`):
The names (single string or list of strings) of the adapter to delete.
Example:
+1 -1
View File
@@ -290,7 +290,7 @@ class FromSingleFileMixin:
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
+6 -1
View File
@@ -34,6 +34,7 @@ from .single_file_utils import (
convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_cosmos_transformer_checkpoint_to_diffusers,
convert_flux2_transformer_checkpoint_to_diffusers,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
@@ -162,6 +163,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": lambda x: x,
"default_subfolder": "transformer",
},
"Flux2Transformer2DModel": {
"checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}
@@ -229,7 +234,7 @@ class FromOriginalModelMixin:
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
+170
View File
@@ -140,6 +140,7 @@ CHECKPOINT_KEY_NAMES = {
"net.blocks.0.self_attn.q_proj.weight",
"net.pos_embedder.dim_spatial_range",
],
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -189,6 +190,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"},
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
@@ -649,6 +651,9 @@ def infer_diffusers_model_type(checkpoint):
else:
model_type = "animatediff_v3"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux2"]):
model_type = "flux-2-dev"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
@@ -3647,3 +3652,168 @@ def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
# Image and text input projections
"img_in": "x_embedder",
"txt_in": "context_embedder",
# Timestep and guidance embeddings
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
# Modulation parameters
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
"single_stream_modulation.lin": "single_stream_modulation.linear",
# Final output layer
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
"final_layer.linear": "proj_out",
}
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
"final_layer.adaLN_modulation.1": "norm_out.linear",
}
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
# Handle fused QKV projections separately as we need to break into Q, K, V projections
"img_attn.norm.query_norm": "attn.norm_q",
"img_attn.norm.key_norm": "attn.norm_k",
"img_attn.proj": "attn.to_out.0",
"img_mlp.0": "ff.linear_in",
"img_mlp.2": "ff.linear_out",
"txt_attn.norm.query_norm": "attn.norm_added_q",
"txt_attn.norm.key_norm": "attn.norm_added_k",
"txt_attn.proj": "attn.to_add_out",
"txt_mlp.0": "ff_context.linear_in",
"txt_mlp.2": "ff_context.linear_out",
}
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
"linear1": "attn.to_qkv_mlp_proj",
"linear2": "attn.to_out",
"norm.query_norm": "attn.norm_q",
"norm.key_norm": "attn.norm_k",
}
def convert_flux2_single_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
# Skip if not a weight, bias, or scale
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
return
# Mapping:
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
new_prefix = "single_transformer_blocks"
if "single_blocks." in key:
parts = key.split(".")
block_idx = parts[1]
within_block_name = ".".join(parts[2:-1])
param_type = parts[-1]
if param_type == "scale":
param_type = "weight"
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_ada_layer_norm_weights(key: str, state_dict: dict[str, object]) -> None:
# Skip if not a weight
if ".weight" not in key:
return
# If adaLN_modulation is in the key, swap scale and shift parameters
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
if "adaLN_modulation" in key:
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
# Assume all such keys are in the AdaLayerNorm key map
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
new_key = ".".join([new_key_without_param_type, param_type])
swapped_weight = swap_scale_shift(state_dict.pop(key), 0)
state_dict[new_key] = swapped_weight
return
def convert_flux2_double_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
# Skip if not a weight, bias, or scale
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
return
new_prefix = "transformer_blocks"
if "double_blocks." in key:
parts = key.split(".")
block_idx = parts[1]
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
within_block_name = ".".join(parts[2:-1])
param_type = parts[-1]
if param_type == "scale":
param_type = "weight"
if "qkv" in within_block_name:
fused_qkv_weight = state_dict.pop(key)
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
if "img" in modality_block_name:
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = "attn.to_q"
new_k_name = "attn.to_k"
new_v_name = "attn.to_v"
elif "txt" in modality_block_name:
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = "attn.add_q_proj"
new_k_name = "attn.add_k_proj"
new_v_name = "attn.add_v_proj"
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
state_dict[new_q_key] = to_q_weight
state_dict[new_k_key] = to_k_weight
state_dict[new_v_key] = to_v_weight
else:
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
param = state_dict.pop(key)
state_dict[new_key] = param
return
def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"adaLN_modulation": convert_ada_layer_norm_weights,
"double_blocks": convert_flux2_double_stream_blocks,
"single_blocks": convert_flux2_single_stream_blocks,
}
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
# Handle official code --> diffusers key remapping via the remap dict
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict(converted_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
+10 -10
View File
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Dict, List, Optional, Union
import safetensors
import torch
@@ -112,7 +112,7 @@ class TextualInversionLoaderMixin:
Load Textual Inversion tokens and embeddings to the tokenizer and text encoder.
"""
def maybe_convert_prompt(self, prompt: str | list[str], tokenizer: "PreTrainedTokenizer"): # noqa: F821
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
r"""
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
@@ -127,14 +127,14 @@ class TextualInversionLoaderMixin:
Returns:
`str` or list of `str`: The converted prompt
"""
if not isinstance(prompt, list):
if not isinstance(prompt, List):
prompts = [prompt]
else:
prompts = prompt
prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
if not isinstance(prompt, list):
if not isinstance(prompt, List):
return prompts[0]
return prompts
@@ -263,8 +263,8 @@ class TextualInversionLoaderMixin:
@validate_hf_hub_args
def load_textual_inversion(
self,
pretrained_model_name_or_path: str | list[str] | dict[str, torch.Tensor] | list[dict[str, torch.Tensor]],
token: Optional[str | list[str]] = None,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
token: Optional[Union[str, List[str]]] = None,
tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
**kwargs,
@@ -274,7 +274,7 @@ class TextualInversionLoaderMixin:
Automatic1111 formats are supported).
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike` or `list[str or os.PathLike]` or `Dict` or `list[Dict]`):
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
Can be either one of the following or a list of them:
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
@@ -285,7 +285,7 @@ class TextualInversionLoaderMixin:
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
token (`str` or `list[str]`, *optional*):
token (`str` or `List[str]`, *optional*):
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
list, then `token` must also be a list of equal length.
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
@@ -306,7 +306,7 @@ class TextualInversionLoaderMixin:
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -458,7 +458,7 @@ class TextualInversionLoaderMixin:
def unload_textual_inversion(
self,
tokens: Optional[str | list[str]] = None,
tokens: Optional[Union[str, List[str]]] = None,
tokenizer: Optional["PreTrainedTokenizer"] = None,
text_encoder: Optional["PreTrainedModel"] = None,
):
+5 -5
View File
@@ -15,7 +15,7 @@ import os
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Callable
from typing import Callable, Dict, Union
import safetensors
import torch
@@ -66,7 +66,7 @@ class UNet2DConditionLoadersMixin:
unet_name = UNET_NAME
@validate_hf_hub_args
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], **kwargs):
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
r"""
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
defined in
@@ -92,7 +92,7 @@ class UNet2DConditionLoadersMixin:
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -106,7 +106,7 @@ class UNet2DConditionLoadersMixin:
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -412,7 +412,7 @@ class UNet2DConditionLoadersMixin:
def save_attn_procs(
self,
save_directory: str | os.PathLike,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
+9 -7
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, List, Union
from torch import nn
@@ -40,7 +40,9 @@ def _translate_into_actual_layer_name(name):
return ".".join((updown, block, attn))
def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: list[float | Dict], default_scale=1.0):
def _maybe_expand_lora_scales(
unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
blocks_with_transformer = {
"down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
"up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
@@ -62,9 +64,9 @@ def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: list[
def _maybe_expand_lora_scales_for_one_adapter(
scales: float | Dict,
blocks_with_transformer: dict[str, int],
transformer_per_block: dict[str, int],
scales: Union[float, Dict],
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
model: nn.Module,
default_scale: float = 1.0,
):
@@ -74,9 +76,9 @@ def _maybe_expand_lora_scales_for_one_adapter(
Parameters:
scales (`Union[float, Dict]`):
Scales dict to expand.
blocks_with_transformer (`dict[str, int]`):
blocks_with_transformer (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing which blocks have transformer layers
transformer_per_block (`dict[str, int]`):
transformer_per_block (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing how many transformer layers each block has
E.g. turns
+2 -1
View File
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import torch
class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: dict[str, torch.Tensor]):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
super().__init__()
self.layers = torch.nn.ModuleList(state_dict.values())
self.mapping = dict(enumerate(state_dict.keys()))
+12
View File
@@ -35,6 +35,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_flux2"] = ["AutoencoderKLFlux2"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
@@ -86,11 +87,13 @@ if is_torch_available():
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
@@ -102,11 +105,14 @@ if is_torch_available():
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
_import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -137,6 +143,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLFlux2,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
@@ -178,6 +185,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
ChromaTransformer2DModel,
ChronoEditTransformer3DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
@@ -186,6 +194,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
DiTTransformer2DModel,
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
@@ -204,14 +213,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PRXTransformer2DModel,
QwenImageTransformer2DModel,
SanaTransformer2DModel,
SanaVideoTransformer3DModel,
SD3Transformer2DModel,
SkyReelsV2Transformer3DModel,
StableAudioDiTModel,
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
ZImageTransformer2DModel,
)
from .unets import (
I2VGenXLUNet,
+54 -30
View File
@@ -16,7 +16,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
import torch
@@ -44,11 +44,16 @@ class ContextParallelConfig:
Args:
ring_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
total number of devices in the context parallel mesh.
Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
ulysses_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
total number of devices in the context parallel mesh.
Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes
local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
good interconnect bandwidth.
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
Whether to convert output and LSE to float32 for ring attention numerical stability.
rotate_method (`str`, *optional*, defaults to `"allgather"`):
@@ -79,29 +84,46 @@ class ContextParallelConfig:
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.ring_degree == 1 and self.ulysses_degree == 1:
raise ValueError(
"Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
@property
def mesh_shape(self) -> Tuple[int, int]:
return (self.ring_degree, self.ulysses_degree)
@property
def mesh_dim_names(self) -> Tuple[str, str]:
"""Dimension names for the device mesh."""
return ("ring", "ulysses")
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
self._rank = rank
self._world_size = world_size
self._device = device
self._mesh = mesh
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
if self.ulysses_degree * self.ring_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
if self._flattened_mesh is None:
self._flattened_mesh = self._mesh._flatten()
if self._ring_mesh is None:
self._ring_mesh = self._mesh["ring"]
if self._ulysses_mesh is None:
self._ulysses_mesh = self._mesh["ulysses"]
if self._ring_local_rank is None:
self._ring_local_rank = self._ring_mesh.get_local_rank()
if self._ulysses_local_rank is None:
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
@dataclass
@@ -119,7 +141,7 @@ class ParallelConfig:
_rank: int = None
_world_size: int = None
_device: torch.device = None
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
_mesh: torch.distributed.device_mesh.DeviceMesh = None
def setup(
self,
@@ -127,14 +149,14 @@ class ParallelConfig:
world_size: int,
device: torch.device,
*,
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
):
self._rank = rank
self._world_size = world_size
self._device = device
self._cp_mesh = cp_mesh
self._mesh = mesh
if self.context_parallel_config is not None:
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
self.context_parallel_config.setup(rank, world_size, device, mesh)
@dataclass(frozen=True)
@@ -187,17 +209,19 @@ class ContextParallelOutput:
# If the key is a string, it denotes the name of the parameter in the forward function.
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
# to be split across context parallel region.
ContextParallelInputType = dict[
str | int, ContextParallelInput | list[ContextParallelInput] | tuple[ContextParallelInput, ...]
ContextParallelInputType = Dict[
Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
]
# A dictionary where keys denote the output to be gathered across context parallel region, and the
# value denotes the gathering configuration.
ContextParallelOutputType = ContextParallelOutput | list[ContextParallelOutput] | tuple[ContextParallelOutput, ...]
ContextParallelOutputType = Union[
ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
]
# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
# the module should be split/gathered across context parallel region.
ContextParallelModelPlan = dict[str, ContextParallelInputType | ContextParallelOutputType]
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):

Some files were not shown because too many files have changed in this diff Show More