Compare commits

...

110 Commits

Author SHA1 Message Date
sayakpaul ce43e58afe up 2025-11-27 16:34:35 +05:30
sayakpaul 51fbe6a1ed up 2025-11-27 16:24:15 +05:30
sayakpaul d456b5d925 up 2025-11-27 15:07:49 +05:30
sayakpaul b9be438a2a up 2025-11-27 14:43:24 +05:30
sayakpaul e3d6945bcf up 2025-11-27 14:22:53 +05:30
sayakpaul 01f488e6d9 Merge branch 'main' into qwen-pipeline-mixin 2025-11-27 14:16:20 +05:30
sayakpaul d8247cfbb3 up 2025-11-27 14:15:22 +05:30
sayakpaul 57cf1b4134 up 2025-11-27 13:28:03 +05:30
sayakpaul 8094f660cf up 2025-11-27 13:25:30 +05:30
sayakpaul 15e3a0fe60 Revert "p"
This reverts commit c6fc91031e.
2025-11-27 13:14:14 +05:30
Sayak Paul 6bf668c4d2 [chore] remove torch.save from remnant code. (#12717)
remove torch.save from remnant code.
2025-11-27 13:04:09 +05:30
sayakpaul c6fc91031e p 2025-11-27 12:41:29 +05:30
sayakpaul ffc95627a9 Merge branch 'main' into qwen-pipeline-mixin 2025-11-27 12:34:05 +05:30
sayakpaul b3b11b5fc2 up 2025-11-27 12:33:40 +05:30
sayakpaul 511c7a4c40 up 2025-11-27 12:04:29 +05:30
sayakpaul 1579a83d9a up 2025-11-27 12:04:08 +05:30
sayakpaul 64eaf85403 up 2025-11-27 12:01:22 +05:30
sayakpaul 6ac18d00e7 more 2025-11-27 11:53:43 +05:30
sayakpaul 5f8a9b61b5 up 2025-11-27 11:51:42 +05:30
sayakpaul abace0508f up 2025-11-27 11:49:53 +05:30
sayakpaul 2d4f144da4 up 2025-11-27 11:47:37 +05:30
sayakpaul 41c59213da stable diffusion og. 2025-11-27 11:44:44 +05:30
sayakpaul 755dc49d1b stable diffusion og. 2025-11-27 11:44:28 +05:30
sayakpaul 4bd7dd56b8 up 2025-11-27 10:58:03 +05:30
sayakpaul 89ebea4fbd up 2025-11-27 10:53:58 +05:30
sayakpaul b1a883517d Merge branch 'main' into qwen-pipeline-mixin 2025-11-27 10:47:30 +05:30
Jerry Wu e6d4612309 Support unittest for Z-image ️ (#12715)
* Add Support for Z-Image.

* Reformatting with make style, black & isort.

* Remove init, Modify import utils, Merge forward in transformers block, Remove once func in pipeline.

* modified main model forward, freqs_cis left

* refactored to add B dim

* fixed stack issue

* fixed modulation bug

* fixed modulation bug

* fix bug

* remove value_from_time_aware_config

* styling

* Fix neg embed and devide / bug; Reuse pad zero tensor; Turn cat -> repeat; Add hint for attn processor.

* Replace padding with pad_sequence; Add gradient checkpointing.

* Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that.

* Fix Docstring and Make Style.

* Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that."

This reverts commit fbf26b7ed1.

* update z-image docstring

* Revert attention dispatcher

* update z-image docstring

* styling

* Recover attention_dispatch.py with its origin impl, later would special commit for fa3 compatibility.

* Fix prev bug, and support for prompt_embeds pass in args after prompt pre-encode as List of torch Tensor.

* Remove einop dependency.

* remove redundant imports & make fix-copies

* fix import

* Support for num_images_per_prompt>1; Remove redundant unquote variables.

* Fix bugs for num_images_per_prompt with actual batch.

* Add unit tests for Z-Image.

* Refine unitest and skip for cases needed separate test env; Fix compatibility with unitest in model, mostly precision formating.

* Add clean env for test_save_load_float16 separ test; Add Note; Styling.

* Update dtype mentioned by yiyi.

---------

Co-authored-by: liudongyang <liudongyang0114@gmail.com>
2025-11-26 07:18:57 -10:00
David El Malih a88a7b4f03 Improve docstrings and type hints in scheduling_dpmsolver_multistep.py (#12710)
* Improve docstrings and type hints in multiple diffusion schedulers

* docs: update Imagen Video paper link to Hugging Face Papers.
2025-11-26 08:38:41 -08:00
Sayak Paul c8656ed73c [docs] put autopipeline after overview and hunyuanimage in images (#12548)
put autopipeline after overview and hunyuanimage in images
2025-11-26 15:34:22 +05:30
Sayak Paul 94c9613f99 [docs] Correct flux2 links (#12716)
* fix links

* up
2025-11-26 10:46:51 +05:30
Sayak Paul b91e8c0d0b [lora]: Fix Flux2 LoRA NaN test (#12714)
* up

* Update tests/lora/test_lora_layers_flux2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-11-26 09:07:48 +05:30
Andrei Filatov ac7864624b Update script names in README for Flux2 training (#12713) 2025-11-26 07:02:18 +05:30
Sayak Paul 5ffb73d4ae let's go Flux2 🚀 (#12711)
* add vae

* Initial commit for Flux 2 Transformer implementation

* add pipeline part

* small edits to the pipeline and conversion

* update conversion script

* fix

* up up

* finish pipeline

* Remove Flux IP Adapter logic for now

* Remove deprecated 3D id logic

* Remove ControlNet logic for now

* Add link to ViT-22B paper as reference for parallel transformer blocks such as the Flux 2 single stream block

* update pipeline

* Don't use biases for input projs and output AdaNorm

* up

* Remove bias for double stream block text QKV projections

* Add script to convert Flux 2 transformer to diffusers

* make style and make quality

* fix a few things.

* allow sft files to go.

* fix image processor

* fix batch

* style a bit

* Fix some bugs in Flux 2 transformer implementation

* Fix dummy input preparation and fix some test bugs

* fix dtype casting in timestep guidance module.

* resolve conflicts.,

* remove ip adapter stuff.

* Fix Flux 2 transformer consistency test

* Fix bug in Flux2TransformerBlock (double stream block)

* Get remaining Flux 2 transformer tests passing

* make style; make quality; make fix-copies

* remove stuff.

* fix type annotaton.

* remove unneeded stuff from tests

* tests

* up

* up

* add sf support

* Remove unused IP Adapter and ControlNet logic from transformer (#9)

* copied from

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>

* up

* up

* up

* up

* up

* Refactor Flux2Attention into separate classes for double stream and single stream attention

* Add _supports_qkv_fusion to AttentionModuleMixin to allow subclasses to disable QKV fusion

* Have Flux2ParallelSelfAttention inherit from AttentionModuleMixin with _supports_qkv_fusion=False

* Log debug message when calling fuse_projections on a AttentionModuleMixin subclass that does not support QKV fusion

* Address review comments

* Update src/diffusers/pipelines/flux2/pipeline_flux2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* up

* Remove maybe_allow_in_graph decorators for Flux 2 transformer blocks (#12)

* up

* support ostris loras. (#13)

* up

* update schdule

* up

* up (#17)

* add training scripts (#16)

* add training scripts

Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com>

* model cpu offload in validation.

* add flux.2 readme

* add img2img and tests

* cpu offload in log validation

* Apply suggestions from code review

* fix

* up

* fixes

* remove i2i training tests for now.

---------

Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com>
Co-authored-by: linoytsaban <linoy@huggingface.co>

* up

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Daniel Gu <dgu8957@gmail.com>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-10-53-87-203.ec2.internal>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com>
Co-authored-by: linoytsaban <linoy@huggingface.co>
2025-11-25 21:49:04 +05:30
Jerry Wu 4088e8a851 Add Support for Z-Image Series (#12703)
* Add Support for Z-Image.

* Reformatting with make style, black & isort.

* Remove init, Modify import utils, Merge forward in transformers block, Remove once func in pipeline.

* modified main model forward, freqs_cis left

* refactored to add B dim

* fixed stack issue

* fixed modulation bug

* fixed modulation bug

* fix bug

* remove value_from_time_aware_config

* styling

* Fix neg embed and devide / bug; Reuse pad zero tensor; Turn cat -> repeat; Add hint for attn processor.

* Replace padding with pad_sequence; Add gradient checkpointing.

* Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that.

* Fix Docstring and Make Style.

* Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that."

This reverts commit fbf26b7ed1.

* update z-image docstring

* Revert attention dispatcher

* update z-image docstring

* styling

* Recover attention_dispatch.py with its origin impl, later would special commit for fa3 compatibility.

* Fix prev bug, and support for prompt_embeds pass in args after prompt pre-encode as List of torch Tensor.

* Remove einop dependency.

* remove redundant imports & make fix-copies

* fix import

---------

Co-authored-by: liudongyang <liudongyang0114@gmail.com>
2025-11-25 05:50:00 -10:00
Junsong Chen d33d9f6715 fix typo in docs (#12675)
* fix typo in docs

* Update docs/source/en/api/pipelines/sana_video.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-11-24 19:42:16 -08:00
sq dde8754ba2 Fix variable naming typos in community FluxControlNetFillInpaintPipeline (#12701)
- Fixed variable naming typos (maskkk -> mask_fill, mask_imagee -> mask_image_fill, masked_imagee -> masked_image_fill, masked_image_latentsss -> masked_latents_fill)

These changes improve code readability without affecting functionality.
2025-11-24 15:16:11 -08:00
cdutr fbcd3ba6b2 [i8n-pt] Fix grammar and expand Portuguese documentation (#12598)
* Updates Portuguese documentation for Diffusers library

Enhances the Portuguese documentation with:
- Restructured table of contents for improved navigation
- Added placeholder page for in-translation content
- Refined language and improved readability in existing pages
- Introduced a new page on basic Stable Diffusion performance guidance

Improves overall documentation structure and user experience for Portuguese-speaking users

* Removes untranslated sections from Portuguese documentation

Cleans up the Portuguese documentation table of contents by removing placeholder sections marked as "Em tradução" (In translation)

Removes the in_translation.md file and associated table of contents entries for sections that are not yet translated, improving documentation clarity
2025-11-24 14:07:32 -08:00
Sayak Paul d176f61fcf [core] support sage attention + FA2 through kernels (#12439)
* up

* support automatic dispatch.

* disable compile support for now./

* up

* flash too.

* document.

* up

* up

* up

* up
2025-11-24 16:58:07 +05:30
Sayak Paul d13e6c08fd Merge branch 'main' into qwen-pipeline-mixin 2025-11-24 16:39:57 +05:30
sayakpaul 6c0d55de20 up 2025-11-24 16:39:44 +05:30
DefTruth 354d35adb0 bugfix: fix chrono-edit context parallel (#12660)
* bugfix: fix chrono-edit context parallel

* bugfix: fix chrono-edit context parallel

* Update src/diffusers/models/transformers/transformer_chronoedit.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/models/transformers/transformer_chronoedit.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Clean up comments in transformer_chronoedit.py

Removed unnecessary comments regarding parallelization in cross-attention.

* fix style

* fix qc

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-11-24 16:36:53 +05:30
sayakpaul db38c47807 up 2025-11-24 16:35:05 +05:30
sayakpaul 4839692df2 remove sdxl related duplications. 2025-11-24 16:19:33 +05:30
sayakpaul 54adb215a0 sdxl. 2025-11-24 16:19:13 +05:30
sayakpaul c8176bfe04 up 2025-11-24 14:52:45 +05:30
sayakpaul 9322997b24 remove some more. 2025-11-24 14:51:19 +05:30
sayakpaul ef67154217 copy 2025-11-24 14:31:01 +05:30
sayakpaul 7c9dc971ac up 2025-11-24 14:27:48 +05:30
sayakpaul c12a61f216 up 2025-11-24 14:13:29 +05:30
sayakpaul debafc6960 up 2025-11-24 14:08:47 +05:30
sayakpaul 8048623daf move some methods to pipeline specific stuff. 2025-11-24 13:58:57 +05:30
sayakpaul 72fc6ad797 Merge branch 'main' into qwen-pipeline-mixin 2025-11-24 13:46:22 +05:30
SwayStar123 544ba677dd Add FluxLoraLoaderMixin to Fibo pipeline (#12688)
Update pipeline_bria_fibo.py
2025-11-24 13:31:31 +05:30
David El Malih 6f1042e36c Improve docstrings and type hints in scheduling_lms_discrete.py (#12678)
* Enhance type hints and docstrings in LMSDiscreteScheduler class

Updated type hints for function parameters and return types to improve code clarity and maintainability. Enhanced docstrings for several methods, providing clearer descriptions of their functionality and expected arguments. Notable changes include specifying Literal types for certain parameters and ensuring consistent return type annotations across the class.

* docs: Add specific paper reference to `_convert_to_karras` docstring.

* Refactor `_convert_to_karras` docstring in DPMSolverSDEScheduler to include detailed descriptions and a specific paper reference, enhancing clarity and documentation consistency.
2025-11-21 10:18:09 -08:00
Sayak Paul 35dd13c5a4 Merge branch 'main' into qwen-pipeline-mixin 2025-11-20 10:07:38 +05:30
Pratim Dasude d5da453de5 Community Pipeline: FluxFillControlNetInpaintPipeline for FLUX Fill-Based Inpainting with ControlNet (#12649)
* new flux fill controlnet inpaint pipline

* Delete src/diffusers/pipelines/flux/pipline_flux_fill_controlnet_Inpaint.py

deleting from main flux pipeline

* Fluc_fill_controlnet community pipline

* Update README.md

* Apply style fixes
2025-11-19 16:18:46 -03:00
David El Malih 15370f8412 Improve docstrings and type hints in scheduling_pndm.py (#12676)
* Enhance docstrings and type hints in PNDMScheduler class

- Updated parameter descriptions to include default values and specific types using Literal for better clarity.
- Improved docstring formatting and consistency across methods, including detailed explanations for the `_get_prev_sample` method.
- Added type hints for method return types to enhance code readability and maintainability.

* Refactor docstring in PNDMScheduler class to enhance clarity

- Simplified the explanation of the method for computing the previous sample from the current sample.
- Updated the reference to the PNDM paper for better accessibility.
- Removed redundant notation explanations to streamline the documentation.
2025-11-19 09:36:41 -08:00
Dhruv Nair a96b145304 [CI] Fix failing Pipeline CPU tests (#12681)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-19 21:19:24 +05:30
Dhruv Nair 6d8973ffe2 [CI] Fix indentation issue in workflow files (#12685)
update
2025-11-19 09:30:04 +05:30
Sayak Paul ab71f3c864 [core] Refactor hub attn kernels (#12475)
* refactor how attention kernels from hub are used.

* up

* refactor according to Dhruv's ideas.

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: dn6 <dhruv@huggingface.co>

* up

---------

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-11-19 08:19:00 +05:30
Dhruv Nair b7df4a5387 [CI] Temporarily pin transformers (#12677)
* update

* update

* update

* update
2025-11-18 14:43:06 +05:30
dg845 67dc65e2e3 Revert AutoencoderKLWan's dim_mult default value back to list (#12640)
Revert dim_mult back to list and fix type annotation
2025-11-17 18:39:53 +05:30
Dhruv Nair 3579fdabf9 [CI] Make CI logs less verbose (#12674)
update
2025-11-17 14:23:09 +05:30
Junsong Chen 1afc21855e SANA-Video Image to Video pipeline SanaImageToVideoPipeline support (#12634)
* move sana-video to a new dir and add `SanaImageToVideoPipeline` with no modify;

* fix bug and run text/image-to-vidoe success;

* make style; quality; fix-copies;

* add sana image-to-video pipeline in markdown;

* add test case for sana image-to-video;

* make style;

* add a init file in sana-video test dir;

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana_video/test_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana_video/test_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* minor update;

* fix bug and skip fp16 save test;

Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* add copied from for `encode_prompt`

* Apply style fixes

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-17 00:23:34 -08:00
David Bertoin 0c35b580fe [PRX pipeline]: add 1024 resolution ratio bins (#12670)
add 1024 ratio bins
2025-11-17 10:37:40 +05:30
David Bertoin 01a56927f1 Rope in float32 for mps or npu compatibility (#12665)
rope in float32
2025-11-15 20:44:34 +05:30
dg845 a9e4883b6a Update Wan Animate Docs (#12658)
* Update the Wan Animate docs to reflect the most recent code

* Further explain input preprocessing and link to original Wan Animate preprocessing scripts
2025-11-14 16:06:22 -08:00
David El Malih 63dd601758 Improve docstrings and type hints in scheduling_euler_discrete.py (#12654)
* refactor: enhance type hints and documentation in EulerDiscreteScheduler

Updated type hints for function parameters and return types in the EulerDiscreteScheduler class to improve code clarity and maintainability. Enhanced docstrings for several methods to provide clearer descriptions of their functionality and expected arguments. This includes specifying Literal types for certain parameters and ensuring consistent return type annotations across the class.

* refactor: enhance type hints and documentation across multiple schedulers

Updated type hints and improved docstrings in various scheduler classes, including CMStochasticIterativeScheduler, CosineDPMSolverMultistepScheduler, and others. This includes specifying parameter types, return types, and providing clearer descriptions of method functionalities. Notable changes include the addition of default values in the begin_index argument and enhanced explanations for noise addition methods. These improvements aim to enhance code clarity and maintainability across the scheduling module.

* refactor: update docstrings to clarify noise schedule construction

Revised docstrings across multiple scheduler classes to enhance clarity regarding the construction of noise schedules. Updated references to relevant papers, ensuring accurate citations for the methodologies used. This includes changes in DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, and others, improving documentation consistency and readability.
2025-11-14 15:12:24 -08:00
Dhruv Nair eeae0338e7 [Modular] Add Custom Blocks guide to doc (#12339)
* update

* update

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/_toctree.yml

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* update

* update

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* update

* update

* update

* update

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-11-14 10:59:59 +05:30
David El Malih 3c1ca869d7 Improve docstrings and type hints in scheduling_ddpm.py (#12651)
* Enhance type hints and docstrings in scheduling_ddpm.py

- Added type hints for function parameters and return types across the DDPMScheduler class and related functions.
- Improved docstrings for clarity, including detailed descriptions of parameters and return values.
- Updated the alpha_transform_type and beta_schedule parameters to use Literal types for better type safety.
- Refined the _get_variance and previous_timestep methods with comprehensive documentation.

* Refactor docstrings and type hints in scheduling_ddpm.py

- Cleaned up whitespace in the rescale_zero_terminal_snr function.
- Enhanced the variance_type parameter in the DDPMScheduler class with improved formatting for better readability.
- Updated the docstring for the compute_variance method to maintain consistency and clarity in parameter descriptions and return values.

* Apply `make fix-copies`

* Refactor type hints across multiple scheduler files

- Updated type hints to include `Literal` for improved type safety in various scheduling files.
- Ensured consistency in type hinting for parameters and return types across the affected modules.
- This change enhances code clarity and maintainability.
2025-11-13 14:46:23 -08:00
David El Malih 6fe4a6ff8e Improve docstrings and type hints in scheduling_ddim.py (#12622)
* Improve docstrings and type hints in scheduling_ddim.py

- Add complete type hints for all function parameters
- Enhance docstrings to follow project conventions
- Add missing parameter descriptions

Fixes #9567

* Enhance docstrings and type hints in scheduling_ddim.py

- Update parameter types and descriptions for clarity
- Improve explanations in method docstrings to align with project standards
- Add optional annotations for parameters where applicable

* Refine type hints and docstrings in scheduling_ddim.py

- Update parameter types to use Literal for specific string options
- Enhance docstring descriptions for clarity and consistency
- Ensure all parameters have appropriate type annotations and defaults

* Apply review feedback on scheduling_ddim.py

- Replace "prevent singularities" with "avoid numerical instability" for better clarity
- Add backticks around `alpha_bar` variable name for consistent formatting
- Convert Imagen Video paper URLs to Hugging Face papers references

* Propagate changes using 'make fix-copies'

* Add missing Literal
2025-11-13 14:45:58 -08:00
Steven Liu 40de88af8c [docs] AutoModel (#12644)
* automodel

* fix
2025-11-13 08:43:24 -08:00
Steven Liu 6a2309b98d [utils] Update check_doc_toc (#12642)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-13 08:42:31 -08:00
Sayak Paul cd3bbe2910 skip autoencoderdl layerwise casting memory (#12647) 2025-11-13 12:56:22 +05:30
kaixuanliu 7a001c3ee2 adjust unit tests for test_save_load_float16 (#12500)
* adjust unit tests for wan pipeline

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* avoid adjusting common `get_dummy_components` API

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* use `form_pretrained` to `transformer` and `transformer_2`

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-11-13 11:57:12 +05:30
dg845 d8e4805816 [WIP]Add Wan2.2 Animate Pipeline (Continuation of #12442 by tolgacangoz) (#12526)
---------

Co-authored-by: Tolga Cangöz <mtcangoz@gmail.com>
Co-authored-by: Tolga Cangöz <46008593+tolgacangoz@users.noreply.github.com>
2025-11-12 16:52:31 -10:00
David El Malih 44c3101685 Improve docstrings and type hints in scheduling_amused.py (#12623)
* Improve docstrings and type hints in scheduling_amused.py

- Add complete type hints for helper functions (gumbel_noise, mask_by_random_topk)
- Enhance AmusedSchedulerOutput with proper Optional typing
- Add comprehensive docstrings for AmusedScheduler class
- Improve __init__, set_timesteps, step, and add_noise methods
- Fix type hints to match documentation conventions
- All changes follow project standards from issue #9567

* Enhance type hints and docstrings in scheduling_amused.py

- Update type hints for `prev_sample` and `pred_original_sample` in `AmusedSchedulerOutput` to reflect their tensor types.
- Improve docstring for `gumbel_noise` to specify the output tensor's dtype and device.
- Refine `AmusedScheduler` class documentation, including detailed descriptions of the masking schedule and temperature parameters.
- Adjust type hints in `set_timesteps` and `step` methods for better clarity and consistency.

* Apply review feedback on scheduling_amused.py

- Replace generic [Amused] reference with specific [`AmusedPipeline`] reference for consistency with project documentation conventions
2025-11-12 17:26:10 -08:00
YiYi Xu d6c63bb956 [modular] add a check (#12628)
* add

* fix
2025-11-12 07:59:18 -10:00
Steven Liu 2f44d63046 [docs] Update install instructions (#12626)
remove commit

Removed specific commit reference for installation instructions.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-11-12 09:21:24 -08:00
Quentin Gallouédec f3db38c1e7 ArXiv -> HF Papers (#12583)
* Update pipeline_skyreels_v2_i2v.py

* Update README.md

* Update torch_utils.py

* Update torch_utils.py

* Update guider_utils.py

* Update pipeline_ltx.py

* Update pipeline_bria.py

* Apply suggestion from @qgallouedec

* Update autoencoder_kl_qwenimage.py

* Update pipeline_prx.py

* Update pipeline_wan_vace.py

* Update pipeline_skyreels_v2.py

* Update pipeline_skyreels_v2_diffusion_forcing.py

* Update pipeline_bria_fibo.py

* Update pipeline_skyreels_v2_diffusion_forcing_i2v.py

* Update pipeline_ltx_condition.py

* Update pipeline_ltx_image2video.py

* Update regional_prompting_stable_diffusion.py

* make style

* style

* style
2025-11-12 08:37:21 -08:00
Sayak Paul f5e5f34823 [modular] add tests for qwen modular (#12585)
* add tests for qwenimage modular.

* qwenimage edit.

* qwenimage edit plus.

* empty

* align with the latest structure

* up

* up

* reason

* up

* fix multiple issues.

* up

* up

* fix

* up

* make it similar to the original pipeline.
2025-11-12 17:37:42 +05:30
YiYi Xu 093cd3f040 fix dispatch_attention_fn check (#12636)
* fix

* fix
2025-11-11 19:16:13 -10:00
a120092009 aecf0c53bf Add MLU Support. (#12629)
* Add MLU Support.

* fix comment.

* rename is_mlu_available to is_torch_mlu_available

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-11 19:15:26 -10:00
YiYi Xu 0c7589293b fix copies (#12637)
* fix

* remoce cocpies instead
2025-11-11 15:44:55 -10:00
Charchit Sharma ff263947ad Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers (#12594)
* Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers

- Store t_dim, h_dim, w_dim as instance variables in WanRotaryPosEmbed and SkyReelsV2RotaryPosEmbed __init__
- Use stored dimensions in forward() instead of recalculating with different formula
- Fixes inconsistency between init (using // 6) and forward (using // 3)
- Ensures split_sizes matches the dimensions used to create rotary embeddings

* quality fix

---------

Co-authored-by: Charchit Sharma <charchitsharma@A-267.local>
2025-11-11 11:45:36 -10:00
Dhruv Nair 66e6a0215f [CI] Remove unittest dependency from testing_utils.py (#12621)
* update

* Update tests/testing_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/testing_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-11 16:40:39 +05:30
Cesaryuan 5a47442f92 Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples (#12544)
* Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-10 13:57:52 -08:00
Dhruv Nair 8f6328c4a4 [Modular] Clean up docs (#12604)
update

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-11-10 23:37:29 +05:30
Dhruv Nair 8d45f219d0 Fix Context Parallel validation checks (#12446)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-10 23:37:07 +05:30
Yashwant Bezawada 0fd58c7706 fix: correct import path for load_model_dict_into_meta in conversion scripts (#12616)
The function load_model_dict_into_meta was moved from modeling_utils.py to
model_loading_utils.py but the imports in the conversion scripts were not
updated, causing ImportError when running these scripts.

This fixes the import in 6 conversion scripts:
- scripts/convert_sd3_to_diffusers.py
- scripts/convert_stable_cascade_lite.py
- scripts/convert_stable_cascade.py
- scripts/convert_stable_audio.py
- scripts/convert_sana_to_diffusers.py
- scripts/convert_sana_controlnet_to_diffusers.py

Fixes #12606
2025-11-10 14:47:18 +05:30
Dhruv Nair 35d703310c [CI] Fix typo in uv install (#12618)
update
2025-11-10 13:22:46 +05:30
YiYi Xu b455dc94a2 [modular] wan! (#12611)
* update, remove intermediaate_inputs

* support image2video

* revert dynamic steps to simplify

* refactor vae encoder block

* support flf2video!

* add support for wan2.2 14B

* style

* Apply suggestions from code review

* input dynamic step -> additiional input step

* up

* fix init

* update dtype
2025-11-09 21:48:50 -10:00
Sayak Paul 8ca0fa8ea4 Merge branch 'main' into qwen-pipeline-mixin 2025-11-04 08:34:36 +05:30
Sayak Paul 8832deef93 Merge branch 'main' into qwen-pipeline-mixin 2025-10-26 00:08:19 +05:30
Sayak Paul 8885a13c9a Merge branch 'main' into qwen-pipeline-mixin 2025-10-13 10:50:58 +05:30
Sayak Paul c5e9a4a648 Merge branch 'main' into qwen-pipeline-mixin 2025-10-03 17:21:05 +05:30
Sayak Paul fc87f40e7a Merge branch 'main' into qwen-pipeline-mixin 2025-09-25 08:09:10 +05:30
sayakpaul 9fecdc973b up 2025-09-23 11:50:33 +05:30
sayakpaul a2b7de3611 Merge branch 'main' into qwen-pipeline-mixin 2025-09-23 11:43:23 +05:30
sayakpaul 5b3295ad48 apply to flux 2025-09-23 11:43:15 +05:30
sayakpaul 78f292ea77 propgate changes for qwenimagedit plus. 2025-09-22 09:39:04 +05:30
sayakpaul d684d4647f Merge branch 'main' into qwen-pipeline-mixin 2025-09-22 09:27:33 +05:30
Sayak Paul b2f0ff7454 Merge branch 'main' into qwen-pipeline-mixin 2025-09-20 13:34:58 +05:30
sayakpaul 435a8c02af up 2025-09-20 13:19:40 +05:30
Sayak Paul d7ef6a0104 Merge branch 'main' into qwen-pipeline-mixin 2025-09-19 17:46:06 +05:30
Sayak Paul c9a9559600 Merge branch 'main' into qwen-pipeline-mixin 2025-09-15 15:20:01 +05:30
sayakpaul 9df6c2f580 remove more. 2025-09-15 15:06:05 +05:30
sayakpaul 13cf2b0c28 up 2025-09-12 17:36:15 +05:30
sayakpaul 369847397e up 2025-09-12 17:05:14 +05:30
sayakpaul d47d60f3e6 up 2025-09-12 17:04:53 +05:30
441 changed files with 25904 additions and 13168 deletions
+21 -7
View File
@@ -73,6 +73,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -84,7 +86,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
--report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
tests/pipelines/${{ matrix.module }}
@@ -126,6 +128,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
@@ -138,7 +142,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
--report-log=tests_torch_${{ matrix.module }}_cuda.log \
tests/${{ matrix.module }}
@@ -151,7 +155,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v --make-reports=examples_torch_cuda \
--make-reports=examples_torch_cuda \
--report-log=examples_torch_cuda.log \
examples/
@@ -190,6 +194,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -198,7 +204,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -232,6 +238,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -281,6 +289,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -293,7 +303,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_version_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -358,6 +368,8 @@ jobs:
uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
fi
uv pip install pytest-reportlog
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -405,6 +417,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip install -U bitsandbytes optimum_quanto
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -531,7 +545,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
@@ -587,7 +601,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
+3 -2
View File
@@ -109,7 +109,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -120,7 +121,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
+8 -6
View File
@@ -115,7 +115,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -126,7 +127,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/pipelines
@@ -134,7 +135,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_models' }}
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and not Dependency" \
-k "not Flax and not Onnx and not Dependency" \
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
@@ -246,7 +247,8 @@ jobs:
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
uv pip install -U tokenizers
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -255,11 +257,11 @@ jobs:
- name: Run fast PyTorch LoRA tests with PEFT
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
\
--make-reports=tests_peft_main \
tests/lora/
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
\
--make-reports=tests_models_lora_peft_main \
tests/models/ -k "lora"
+18 -15
View File
@@ -1,4 +1,4 @@
name: Fast GPU Tests on PR
name: Fast GPU Tests on PR
on:
pull_request:
@@ -71,7 +71,7 @@ jobs:
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
setup_torch_cuda_pipeline_matrix:
needs: [check_code_quality, check_repository_consistency]
name: Setup Torch Pipelines CUDA Slow Tests Matrix
@@ -131,7 +131,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -149,18 +150,18 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
else
else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
-k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
fi
fi
- name: Failure short reports
if: ${{ failure() }}
@@ -201,7 +202,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -222,11 +224,11 @@ jobs:
run: |
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
if [ -z "$pattern" ]; then
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
else
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
fi
- name: Failure short reports
@@ -262,7 +264,8 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install -e ".[quality,training]"
- name: Environment
@@ -274,7 +277,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+11 -8
View File
@@ -76,7 +76,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -87,7 +88,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -128,7 +129,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -141,7 +143,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_cuda_${{ matrix.module }} \
tests/${{ matrix.module }}
@@ -180,7 +182,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
uv pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -189,7 +192,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -230,7 +233,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -273,7 +276,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+1 -1
View File
@@ -70,7 +70,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch' }}
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
+1 -1
View File
@@ -57,7 +57,7 @@ jobs:
HF_HOME: /System/Volumes/Data/mnt/cache
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
${CONDA_RUN} python -m pytest -n 0 --make-reports=tests_torch_mps tests/
- name: Failure short reports
if: ${{ failure() }}
+6 -6
View File
@@ -84,7 +84,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -137,7 +137,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
tests/${{ matrix.module }}
@@ -187,7 +187,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -240,7 +240,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -281,7 +281,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -326,7 +326,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+14 -4
View File
@@ -22,6 +22,8 @@
title: Reproducibility
- local: using-diffusers/schedulers
title: Schedulers
- local: using-diffusers/automodel
title: AutoModel
- local: using-diffusers/other-formats
title: Model formats
- local: using-diffusers/push_to_hub
@@ -119,6 +121,8 @@
title: ComponentsManager
- local: modular_diffusers/guiders
title: Guiders
- local: modular_diffusers/custom_blocks
title: Building Custom Blocks
title: Modular Diffusers
- isExpanded: false
sections:
@@ -345,6 +349,8 @@
title: DiTTransformer2DModel
- local: api/models/easyanimate_transformer3d
title: EasyAnimateTransformer3DModel
- local: api/models/flux2_transformer
title: Flux2Transformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/hidream_image_transformer
@@ -387,6 +393,8 @@
title: Transformer2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
- local: api/models/wan_animate_transformer_3d
title: WanAnimateTransformer3DModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
title: Transformers
@@ -448,6 +456,8 @@
- sections:
- local: api/pipelines/overview
title: Overview
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- sections:
- local: api/pipelines/audioldm
title: AudioLDM
@@ -460,8 +470,6 @@
- local: api/pipelines/stable_audio
title: Stable Audio
title: Audio
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- sections:
- local: api/pipelines/amused
title: aMUSEd
@@ -519,12 +527,16 @@
title: EasyAnimate
- local: api/pipelines/flux
title: Flux
- local: api/pipelines/flux2
title: Flux2
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuanimage21
title: HunyuanImage2.1
- local: api/pipelines/pix2pix
title: InstructPix2Pix
- local: api/pipelines/kandinsky
@@ -638,8 +650,6 @@
title: ConsisID
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/hunyuanimage21
title: HunyuanImage2.1
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
+6 -1
View File
@@ -30,7 +30,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
> [!TIP]
@@ -56,6 +57,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
## Flux2LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
## CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
+1 -9
View File
@@ -12,15 +12,7 @@ specific language governing permissions and limitations under the License.
# AutoModel
The `AutoModel` is designed to make it easy to load a checkpoint without needing to know the specific model class. `AutoModel` automatically retrieves the correct model class from the checkpoint `config.json` file.
```python
from diffusers import AutoModel, AutoPipelineForText2Image
unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
pipe = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet)
```
[`AutoModel`] automatically retrieves the correct model class from the checkpoint `config.json` file.
## AutoModel
@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Flux2Transformer2DModel
A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-labs/FLUX.2-dev).
## Flux2Transformer2DModel
[[autodoc]] Flux2Transformer2DModel
@@ -0,0 +1,30 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# WanAnimateTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team.
The model can be loaded with the following code snippet.
```python
from diffusers import WanAnimateTransformer3DModel
transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## WanAnimateTransformer3DModel
[[autodoc]] WanAnimateTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
+33
View File
@@ -0,0 +1,33 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Flux2
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Flux.2 is the recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture and pre-training done from scratch!
Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2).
> [!TIP]
> Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
>
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## Flux2Pipeline
[[autodoc]] Flux2Pipeline
- all
- __call__
+89 -2
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
# SanaVideoPipeline
# Sana-Video
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
@@ -37,6 +37,86 @@ Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-vi
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
## Generation Pipelines
<hfoptions id="generation pipelines">`
<hfoption id="Text-to-Video">
The example below demonstrates how to use the text-to-video pipeline to generate a video using a text description.
```python
pipe = SanaVideoPipeline.from_pretrained(
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
torch_dtype=torch.bfloat16,
)
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.float32)
pipe.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
frames=81,
guidance_scale=6,
num_inference_steps=50,
generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]
export_to_video(video, "sana_video.mp4", fps=16)
```
</hfoption>
<hfoption id="Image-to-Video">
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description and a starting frame.
```python
pipe = SanaImageToVideoPipeline.from_pretrained(
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
torch_dtype=torch.bfloat16,
)
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.vae.to(torch.float32)
pipe.text_encoder.to(torch.bfloat16)
pipe.to("cuda")
image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt
motion_scale = 30.0
video = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
frames=81,
guidance_scale=6,
num_inference_steps=50,
generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]
export_to_video(video, "sana-i2v.mp4", fps=16)
```
</hfoption>
</hfoptions>
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
@@ -97,6 +177,13 @@ export_to_video(output, "sana-video-output.mp4", fps=16)
- __call__
## SanaImageToVideoPipeline
[[autodoc]] SanaImageToVideoPipeline
- all
- __call__
## SanaVideoPipelineOutput
[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput
[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput
+226 -17
View File
@@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers:
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)
> [!TIP]
> Click on the Wan models in the right sidebar for more examples of video generation.
@@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained(
pipeline.to("cuda")
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
@@ -150,15 +151,15 @@ pipeline.transformer = torch.compile(
)
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
@@ -249,6 +250,208 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
</hfoption>
</hfoptions>
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*
The project page: https://humanaigc.github.io/wan-animate
This model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
#### Usage
The Wan-Animate pipeline supports two modes of operation:
1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos
2. **Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene
##### Prerequisites
Before using the pipeline, you need to preprocess your reference video to extract:
- **Pose video**: Contains skeletal keypoints representing body motion
- **Face video**: Contains facial feature representations for expression control
For replacement mode, you additionally need:
- **Background video**: The original video containing the scene
- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black)
> [!NOTE]
> Raw videos should not be used for inputs such as `pose_video`, which the pipeline expects to be preprocessed to extract the proper information. Preprocessing scripts to prepare these inputs are available in the [original Wan-Animate repository](https://github.com/Wan-Video/Wan2.2?tab=readme-ov-file#1-preprocessing). Integration of these preprocessing steps into Diffusers is planned for a future release.
The example below demonstrates how to use the Wan-Animate pipeline:
<hfoptions id="Animate usage">
<hfoption id="Animation mode">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load character image and preprocessed videos
image = load_image("path/to/character.jpg")
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
# Resize image to match VAE constraints
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person dancing energetically in a studio with dynamic lighting and professional camera work"
negative_prompt = "blurry, low quality, distorted, deformed, static, poorly drawn"
# Generate animated video
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
segment_frame_length=77,
guidance_scale=1.0,
mode="animate", # Animation mode (default)
).frames[0]
export_to_video(output, "animated_character.mp4", fps=30)
```
</hfoption>
<hfoption id="Replacement mode">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load all required inputs for replacement mode
image = load_image("path/to/new_character.jpg")
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
background_video = load_video("path/to/background_video.mp4") # Original scene
mask_video = load_video("path/to/mask_video.mp4") # Black: preserve, White: generate
# Resize image to match video dimensions
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person seamlessly integrated into the scene with consistent lighting and environment"
negative_prompt = "blurry, low quality, inconsistent lighting, floating, disconnected from scene"
# Replace character in background video
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
background_video=background_video,
mask_video=mask_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
segment_frame_lengths=77,
guidance_scale=1.0,
mode="replace", # Replacement mode
).frames[0]
export_to_video(output, "character_replaced.mp4", fps=30)
```
</hfoption>
<hfoption id="Advanced options">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = load_image("path/to/character.jpg")
pose_video = load_video("path/to/pose_video.mp4")
face_video = load_video("path/to/face_video.mp4")
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person dancing energetically in a studio"
negative_prompt = "blurry, low quality"
# Advanced: Use temporal guidance and custom callback
def callback_fn(pipe, step_index, timestep, callback_kwargs):
# You can modify latents or other tensors here
print(f"Step {step_index}, Timestep {timestep}")
return callback_kwargs
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
segment_frame_length=77,
num_inference_steps=50,
guidance_scale=5.0,
prev_segment_conditioning_frames=5, # Use 5 frames for temporal guidance (1 or 5 recommended)
callback_on_step_end=callback_fn,
callback_on_step_end_tensor_inputs=["latents"],
).frames[0]
export_to_video(output, "animated_advanced.mp4", fps=30)
```
</hfoption>
</hfoptions>
#### Key Parameters
- **mode**: Choose between `"animate"` (default) or `"replace"`
- **prev_segment_conditioning_frames**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory
- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt. For Wan-Animate, CFG is disabled by default (`guidance_scale=1.0`) but can be enabled to support negative prompts and finer control over facial expressions. (Note that CFG will only target the text prompt and face conditioning.)
## Notes
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
@@ -281,10 +484,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
# use "steamboat willie style" to trigger the LoRA
prompt = """
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
@@ -359,6 +562,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
- all
- __call__
## WanAnimatePipeline
[[autodoc]] WanAnimatePipeline
- all
- __call__
## WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
@@ -0,0 +1,492 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Building Custom Blocks
[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.
> [!TIP]
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana.
## Project Structure
Your custom block project should use the following structure:
```shell
.
├── block.py
└── modular_config.json
```
- `block.py` contains the custom block implementation
- `modular_config.json` contains the metadata needed to load the block
## Example: Florence 2 Inpainting Block
In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.
The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub.
```py
# Inside block.py
from diffusers.modular_pipelines import (
ModularPipelineBlocks,
ComponentSpec,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
```
Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
```py
from typing import List, Union
from PIL import Image, ImageDraw
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
type_hint=Union[Image.Image, List[Image.Image]],
required=True,
description="Image(s) to annotate",
),
InputParam(
"annotation_task",
type_hint=Union[str, List[str]],
required=True,
default="<REFERRING_EXPRESSION_SEGMENTATION>",
description="""Annotation Task to perform on the image.
Supported Tasks:
<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
""",
),
InputParam(
"annotation_prompt",
type_hint=Union[str, List[str]],
required=True,
description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
),
InputParam(
"annotation_output_type",
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Availabe options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
- mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
),
InputParam(
"annotation_overlay",
type_hint=bool,
required=True,
default=False,
description="",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"mask_image",
type_hint=Image,
description="Inpainting Mask for input Image(s)",
),
OutputParam(
"annotations",
type_hint=dict,
description="Annotations Predictions for input Image(s)",
),
OutputParam(
"image",
type_hint=Image,
description="Annotated input Image(s)",
),
]
```
Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
```py
from typing import List, Union
from PIL import Image, ImageDraw
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
type_hint=Union[Image.Image, List[Image.Image]],
required=True,
description="Image(s) to annotate",
),
InputParam(
"annotation_task",
type_hint=Union[str, List[str]],
required=True,
default="<REFERRING_EXPRESSION_SEGMENTATION>",
description="""Annotation Task to perform on the image.
Supported Tasks:
<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
""",
),
InputParam(
"annotation_prompt",
type_hint=Union[str, List[str]],
required=True,
description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
),
InputParam(
"annotation_output_type",
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Availabe options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
- mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
),
InputParam(
"annotation_overlay",
type_hint=bool,
required=True,
default=False,
description="",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"mask_image",
type_hint=Image,
description="Inpainting Mask for input Image(s)",
),
OutputParam(
"annotations",
type_hint=dict,
description="Annotations Predictions for input Image(s)",
),
OutputParam(
"image",
type_hint=Image,
description="Annotated input Image(s)",
),
]
def get_annotations(self, components, images, prompts, task):
task_prompts = [task + prompt for prompt in prompts]
inputs = components.image_annotator_processor(
text=task_prompts, images=images, return_tensors="pt"
).to(components.image_annotator.device, components.image_annotator.dtype)
generated_ids = components.image_annotator.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
annotations = components.image_annotator_processor.batch_decode(
generated_ids, skip_special_tokens=False
)
outputs = []
for image, annotation in zip(images, annotations):
outputs.append(
components.image_annotator_processor.post_process_generation(
annotation, task=task, image_size=(image.width, image.height)
)
)
return outputs
def prepare_mask(self, images, annotations, overlay=False, fill="white"):
masks = []
for image, annotation in zip(images, annotations):
mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
draw = ImageDraw.Draw(mask_image)
for _, _annotation in annotation.items():
if "polygons" in _annotation:
for polygon in _annotation["polygons"]:
polygon = np.array(polygon).reshape(-1, 2)
if len(polygon) < 3:
continue
polygon = polygon.reshape(-1).tolist()
draw.polygon(polygon, fill=fill)
elif "bbox" in _annotation:
bbox = _annotation["bbox"]
draw.rectangle(bbox, fill="white")
masks.append(mask_image)
return masks
def prepare_bounding_boxes(self, images, annotations):
outputs = []
for image, annotation in zip(images, annotations):
image_copy = image.copy()
draw = ImageDraw.Draw(image_copy)
for _, _annotation in annotation.items():
bbox = _annotation["bbox"]
label = _annotation["label"]
draw.rectangle(bbox, outline="red", width=3)
draw.text((bbox[0], bbox[1] - 20), label, fill="red")
outputs.append(image_copy)
return outputs
def prepare_inputs(self, images, prompts):
prompts = prompts or ""
if isinstance(images, Image.Image):
images = [images]
if isinstance(prompts, str):
prompts = [prompts]
if len(images) != len(prompts):
raise ValueError("Number of images and annotation prompts must match.")
return images, prompts
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
images, annotation_task_prompt = self.prepare_inputs(
block_state.image, block_state.annotation_prompt
)
task = block_state.annotation_task
fill = block_state.fill
annotations = self.get_annotations(
components, images, annotation_task_prompt, task
)
block_state.annotations = annotations
if block_state.annotation_output_type == "mask_image":
block_state.mask_image = self.prepare_mask(images, annotations)
else:
block_state.mask_image = None
if block_state.annotation_output_type == "mask_overlay":
block_state.image = self.prepare_mask(images, annotations, overlay=True, fill=fill)
elif block_state.annotation_output_type == "bounding_box":
block_state.image = self.prepare_bounding_boxes(images, annotations)
self.set_block_state(state, block_state)
return components, state
```
Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines.
<hfoptions id="share">
<hfoption id="hf CLI">
```shell
# In the folder with the `block.py` file, run:
diffusers-cli custom_block
```
Then upload the block to the Hub:
```shell
hf upload <your repo id> . .
```
</hfoption>
<hfoption id="push_to_hub">
```py
from block import Florence2ImageAnnotatorBlock
block = Florence2ImageAnnotatorBlock()
block.push_to_hub("<your repo id>")
```
</hfoption>
</hfoptions>
## Using Custom Blocks
Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
# Fetch the Florence2 image annotator block that will create our mask
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
my_blocks = INPAINT_BLOCKS.copy()
# insert the annotation block before the image encoding step
my_blocks.insert("image_annotator", image_annotator_block, 1)
# Create our initial set of inpainting blocks
blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
pipe = blocks.init_pipeline(repo_id)
pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
image = image.resize((1024, 1024))
prompt = ["A red car"]
annotation_task = "<REFERRING_EXPRESSION_SEGMENTATION>"
annotation_prompt = ["the car"]
output = pipe(
prompt=prompt,
image=image,
annotation_task=annotation_task,
annotation_prompt=annotation_prompt,
annotation_output_type="mask_image",
num_inference_steps=35,
guidance_scale=7.5,
strength=0.95,
output="images"
)
output[0].save("florence-inpainting.png")
```
## Editing Custom Blocks
By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
# Fetch the Florence2 image annotator block that will create our mask
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
```
Any changes made to the block files in this folder will be reflected when you load the block again.
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# LoopSequentialPipelineBlocks
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `intermediate_inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
@@ -21,7 +21,6 @@ This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBl
[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
- `loop_intermediate_inputs` are intermediate variables from the [`~modular_pipelines.PipelineState`] and equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`].
- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
- `__call__` method defines the loop structure and iteration logic.
@@ -90,4 +89,4 @@ Add more loop blocks to run within each iteration with [`~modular_pipelines.Loop
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
```
```
@@ -37,17 +37,7 @@ A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermedi
]
```
- `intermediate_inputs` are values typically created from a previous block but it can also be directly provided if no preceding block generates them. Unlike `inputs`, `intermediate_inputs` can be modified.
Use `InputParam` to define `intermediate_inputs`.
```py
user_intermediate_inputs = [
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
]
```
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `intermediate_inputs` for subsequent blocks or available as the final output from running the pipeline.
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
Use `OutputParam` to define `intermediate_outputs`.
@@ -65,8 +55,8 @@ The intermediate inputs and outputs share data to connect blocks. They are acces
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs` and `intermediate_inputs`.
2. Implement the computation logic on the `inputs` and `intermediate_inputs`.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
2. Implement the computation logic on the `inputs`.
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
4. Return the components and state which becomes available to the next block.
@@ -76,7 +66,7 @@ def __call__(self, components, state):
block_state = self.get_block_state(state)
# Your computation logic here
# block_state contains all your inputs and intermediate_inputs
# block_state contains all your inputs
# Access them like: block_state.image, block_state.processed_image
# Update the pipeline state with your updated block_states
@@ -112,4 +102,4 @@ def __call__(self, components, state):
unet = components.unet
vae = components.vae
scheduler = components.scheduler
```
```
@@ -183,7 +183,7 @@ from diffusers.modular_pipelines import ComponentsManager
components = ComponentManager()
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
dd_pipeline.load_componenets(torch_dtype=torch.float16)
dd_pipeline.to("cuda")
```
@@ -12,11 +12,11 @@ specific language governing permissions and limitations under the License.
# SequentialPipelineBlocks
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `intermediate_inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `intermediate_inputs`.
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.
<hfoptions id="sequential">
<hfoption id="InputBlock">
@@ -110,4 +110,4 @@ Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by cal
```py
print(blocks)
print(blocks.doc)
```
```
@@ -139,12 +139,14 @@ Refer to the table below for a complete list of available attention backends and
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
| `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
@@ -0,0 +1,46 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoModel
The [`AutoModel`] class automatically detects and loads the correct model class (UNet, transformer, VAE) from a `config.json` file. You don't need to know the specific model class name ahead of time. It supports data types and device placement, and works across model types and libraries.
The example below loads a transformer from Diffusers and a text encoder from Transformers. Use the `subfolder` parameter to specify where to load the `config.json` file from.
```py
import torch
from diffusers import AutoModel, DiffusionPipeline
transformer = AutoModel.from_pretrained(
"Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda"
)
text_encoder = AutoModel.from_pretrained(
"Qwen/Qwen-Image", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda"
)
```
[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
```py
import torch
from diffusers import AutoModel
transformer = AutoModel.from_pretrained(
"custom/custom-transformer-model", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
)
```
If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
+8 -6
View File
@@ -1,8 +1,10 @@
- sections:
- local: index
title: 🧨 Diffusers
- local: quicktour
title: Tour rápido
- local: installation
title: Instalação
- local: index
title: Diffusers
- local: installation
title: Instalação
- local: quicktour
title: Tour rápido
- local: stable_diffusion
title: Desempenho básico
title: Primeiros passos
+2 -2
View File
@@ -18,11 +18,11 @@ specific language governing permissions and limitations under the License.
# Diffusers
🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou queira treinar seu próprio modelo de difusão, 🤗 Diffusers é uma modular caixa de ferramentas que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou quer treinar seu próprio modelo de difusão, 🤗 Diffusers é uma caixa de ferramentas modular que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
A Biblioteca tem três componentes principais:
- Pipelines de última geração para a geração em poucas linhas de código. Têm muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
- Pipelines de última geração para a geração em poucas linhas de código. muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
- Intercambiáveis [agendadores de ruído](api/schedulers/overview) para balancear as compensações entre velocidade e qualidade de geração.
- [Modelos](api/models) pré-treinados que podem ser usados como se fossem blocos de construção, e combinados com agendadores, para criar seu próprio sistema de difusão de ponta a ponta.
+3 -3
View File
@@ -21,7 +21,7 @@ specific language governing permissions and limitations under the License.
Recomenda-se instalar 🤗 Diffusers em um [ambiente virtual](https://docs.python.org/3/library/venv.html).
Se você não está familiarizado com ambiente virtuals, veja o [guia](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
Um ambiente virtual deixa mais fácil gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
Um ambiente virtual facilita gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
Comece criando um ambiente virtual no diretório do projeto:
@@ -100,12 +100,12 @@ pip install -e ".[flax]"
</jax>
</frameworkcontent>
Esses comandos irá linkar a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.
Esses comandos irão vincular a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.
Python então irá procurar dentro da pasta que você clonou além dos caminhos normais das bibliotecas.
Por exemplo, se o pacote python for tipicamente instalado no `~/anaconda3/envs/main/lib/python3.10/site-packages/`, o Python também irá procurar na pasta `~/diffusers/` que você clonou.
> [!WARNING]
> Você deve deixar a pasta `diffusers` se você quiser continuar usando a biblioteca.
> Você deve manter a pasta `diffusers` se quiser continuar usando a biblioteca.
Agora você pode facilmente atualizar seu clone para a última versão do 🤗 Diffusers com o seguinte comando:
+132
View File
@@ -0,0 +1,132 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
[[open-in-colab]]
# Desempenho básico
Difusão é um processo aleatório que demanda muito processamento. Você pode precisar executar o [`DiffusionPipeline`] várias vezes antes de obter o resultado desejado. Por isso é importante equilibrar cuidadosamente a velocidade de geração e o uso de memória para iterar mais rápido.
Este guia recomenda algumas dicas básicas de desempenho para usar o [`DiffusionPipeline`]. Consulte a seção de documentação sobre Otimização de Inferência, como [Acelerar inferência](./optimization/fp16) ou [Reduzir uso de memória](./optimization/memory) para guias de desempenho mais detalhados.
## Uso de memória
Reduzir a quantidade de memória usada indiretamente acelera a geração e pode ajudar um modelo a caber no dispositivo.
O método [`~DiffusionPipeline.enable_model_cpu_offload`] move um modelo para a CPU quando não está em uso para economizar memória da GPU.
```py
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
pipeline.enable_model_cpu_offload()
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
pipeline(prompt).images[0]
print(f"Memória máxima reservada: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```
## Velocidade de inferência
O processo de remoção de ruído é o mais exigente computacionalmente durante a difusão. Métodos que otimizam este processo aceleram a velocidade de inferência. Experimente os seguintes métodos para acelerar.
- Adicione `device_map="cuda"` para colocar o pipeline em uma GPU. Colocar um modelo em um acelerador, como uma GPU, aumenta a velocidade porque realiza computações em paralelo.
- Defina `torch_dtype=torch.bfloat16` para executar o pipeline em meia-precisão. Reduzir a precisão do tipo de dado aumenta a velocidade porque leva menos tempo para realizar computações em precisão mais baixa.
```py
import torch
import time
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
```
- Use um agendador mais rápido, como [`DPMSolverMultistepScheduler`], que requer apenas ~20-25 passos.
- Defina `num_inference_steps` para um valor menor. Reduzir o número de passos de inferência reduz o número total de computações. No entanto, isso pode resultar em menor qualidade de geração.
```py
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
start_time = time.perf_counter()
image = pipeline(prompt).images[0]
end_time = time.perf_counter()
print(f"Geração de imagem levou {end_time - start_time:.3f} segundos")
```
## Qualidade de geração
Muitos modelos de difusão modernos entregam imagens de alta qualidade imediatamente. No entanto, você ainda pode melhorar a qualidade de geração experimentando o seguinte.
- Experimente um prompt mais detalhado e descritivo. Inclua detalhes como o meio da imagem, assunto, estilo e estética. Um prompt negativo também pode ajudar, guiando um modelo para longe de características indesejáveis usando palavras como baixa qualidade ou desfocado.
```py
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
negative_prompt = "low quality, blurry, ugly, poor details"
pipeline(prompt, negative_prompt=negative_prompt).images[0]
```
Para mais detalhes sobre como criar prompts melhores, consulte a documentação sobre [Técnicas de prompt](./using-diffusers/weighted_prompts).
- Experimente um agendador diferente, como [`HeunDiscreteScheduler`] ou [`LMSDiscreteScheduler`], que sacrifica velocidade de geração por qualidade.
```py
import torch
from diffusers import DiffusionPipeline, HeunDiscreteScheduler
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
negative_prompt = "low quality, blurry, ugly, poor details"
pipeline(prompt, negative_prompt=negative_prompt).images[0]
```
## Próximos passos
Diffusers oferece otimizações mais avançadas e poderosas, como [group-offloading](./optimization/memory#group-offloading) e [compilação regional](./optimization/fp16#regional-compilation). Para saber mais sobre como maximizar o desempenho, consulte a seção sobre Otimização de Inferência.
+105 -2
View File
@@ -88,7 +88,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
| Flux Fill ControlNet Pipeline | A modified version of the `FluxFillPipeline` and `FluxControlNetInpaintPipeline` that supports Controlnet with Flux Fill model.| [Flux Fill ControlNet Pipeline](#Flux-Fill-ControlNet-Pipeline) | - | [pratim4dasude](https://github.com/pratim4dasude) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -5488,7 +5488,7 @@ Editing at Scale", many thanks to their contribution!
This implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated.
As explained in Section 3 of [the paper](https://arxiv.org/pdf/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
As explained in Section 3 of [the paper](https://huggingface.co/papers/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
## Example Usage
@@ -5527,3 +5527,106 @@ images = pipe(
).images
images[0].save("pizzeria.png")
```
# Flux Fill ControlNet Pipeline
This implementation of Flux Fill + ControlNet Inpaint combines the fill-style masked editing of FLUX.1-Fill-dev with full ControlNet conditioning. The base image is processed through the Fill model while the ControlNet receives the corresponding conditioning input (depth, canny, pose, etc.), and both outputs are fused during denoising to guide structure and composition.
While FLUX.1-Fill-dev is designed for mask-based edits, it was not originally trained to operate jointly with ControlNet. In practice, this combined setup works well for structured inpainting tasks, though results may vary depending on the conditioning strength and the alignment between the mask and the control input.
## Example Usage
```python
import torch
from diffusers import (
FluxControlNetModel,
FluxPriorReduxPipeline,
)
from diffusers.utils import load_image
# NEW PIPELINE (updated name)
from pipline_flux_fill_controlnet_Inpaint import FluxControlNetFillInpaintPipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
# Models
base_model = "black-forest-labs/FLUX.1-Fill-dev"
controlnet_model = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0"
prior_model = "black-forest-labs/FLUX.1-Redux-dev"
# Load ControlNet
controlnet = FluxControlNetModel.from_pretrained(
controlnet_model,
torch_dtype=dtype,
)
# Load Fill + ControlNet Pipeline
fill_pipe = FluxControlNetFillInpaintPipeline.from_pretrained(
base_model,
controlnet=controlnet,
torch_dtype=dtype,
).to(device)
# OPTIONAL FP8
# fill_pipe.transformer.enable_layerwise_casting(
# storage_dtype=torch.float8_e4m3fn,
# compute_dtype=torch.bfloat16
# )
# OPTIONAL Prior Redux
#pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
# prior_model,
# torch_dtype=dtype,
#).to(device)
# Inputs
# combined_image = load_image("person_input.png")
# 1. Prior conditioning
#prior_out = pipe_prior_redux(
# image=cloth_image,
# prompt=cloth_prompt,
#)
# 2. Fill Inpaint with ControlNet
# canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
img = load_image(r"imgs/background.jpg")
mask = load_image(r"imgs/mask.png")
control_image_depth = load_image(r"imgs/dog_depth _2.png")
result = fill_pipe(
prompt="a dog on a bench",
image=img,
mask_image=mask,
control_image=control_image_depth,
control_mode=[2], # union mode
control_guidance_start=0.0,
control_guidance_end=0.8,
controlnet_conditioning_scale=0.9,
height=1024,
width=1024,
strength=1.0,
guidance_scale=50.0,
num_inference_steps=60,
max_sequence_length=512,
# **prior_out,
)
# result.images[0].save("flux_fill_controlnet_inpaint.png")
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
result.images[0].save(f"flux_fill_controlnet_inpaint_depth{timestamp}.jpg")
```
@@ -489,7 +489,6 @@ class AdaptiveMaskInpaintPipeline(
# We'll offload the last model manually.
self.final_offload_hook = hook
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
@@ -651,7 +650,7 @@ class AdaptiveMaskInpaintPipeline(
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -666,7 +665,7 @@ class AdaptiveMaskInpaintPipeline(
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+5 -5
View File
@@ -1380,7 +1380,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
flow_model.eval()
self.flow_model = flow_model
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
def _encode_prompt(
self,
prompt,
@@ -1413,7 +1413,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
def encode_prompt(
self,
prompt,
@@ -1672,7 +1672,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -1687,7 +1687,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
def decode_latents(self, latents):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
@@ -1699,7 +1699,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+1 -1
View File
@@ -45,7 +45,7 @@ def check_size(image, height, width):
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
inner_image = inner_image.convert("RGBA")
image = image.convert("RGB")
@@ -277,7 +277,7 @@ class LatentConsistencyModelWalkPipeline(
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
def encode_prompt(
self,
prompt,
@@ -459,7 +459,7 @@ class LatentConsistencyModelWalkPipeline(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -525,7 +525,7 @@ class LatentConsistencyModelWalkPipeline(
assert emb.shape == (w.shape[0], embedding_dim)
return emb
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+11 -11
View File
@@ -1195,7 +1195,7 @@ class LLMGroundedDiffusionPipeline(
# Below are methods copied from StableDiffusionPipeline
# The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
def _encode_prompt(
self,
prompt,
@@ -1228,7 +1228,7 @@ class LLMGroundedDiffusionPipeline(
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
def encode_prompt(
self,
prompt,
@@ -1426,7 +1426,7 @@ class LLMGroundedDiffusionPipeline(
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -1441,7 +1441,7 @@ class LLMGroundedDiffusionPipeline(
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
def decode_latents(self, latents):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
@@ -1453,7 +1453,7 @@ class LLMGroundedDiffusionPipeline(
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -1534,17 +1534,17 @@ class LLMGroundedDiffusionPipeline(
return emb
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.guidance_scale
def guidance_scale(self):
return self._guidance_scale
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.guidance_rescale
def guidance_rescale(self):
return self._guidance_rescale
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.clip_skip
def clip_skip(self):
return self._clip_skip
@@ -1552,16 +1552,16 @@ class LLMGroundedDiffusionPipeline(
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.do_classifier_free_guidance
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.cross_attention_kwargs
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.num_timesteps
def num_timesteps(self):
return self._num_timesteps
@@ -503,7 +503,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -518,7 +518,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -532,7 +532,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -721,7 +721,7 @@ class SDXLLongPromptWeightingPipeline(
# We'll offload the last model manually.
self.final_offload_hook = hook
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
def encode_prompt(
self,
prompt: str,
@@ -941,7 +941,7 @@ class SDXLLongPromptWeightingPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+13 -8
View File
@@ -121,7 +121,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -136,7 +136,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -1966,16 +1966,21 @@ class MatryoshkaUNet2DConditionModel(
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -2294,10 +2299,10 @@ class MatryoshkaUNet2DConditionModel(
def _check_config(
self,
down_block_types: Tuple[str],
up_block_types: Tuple[str],
down_block_types: Tuple[str, ...],
up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
+3 -3
View File
@@ -196,7 +196,7 @@ def _get_crops_coords_list(num_rows, num_cols, output_width):
return result_list
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
@@ -223,7 +223,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -627,7 +627,7 @@ class StableDiffusionXLTilingPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -244,7 +244,7 @@ def _tile2latent_indices(
return latent_row_init, latent_row_end, latent_col_init, latent_col_end
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -394,7 +394,7 @@ class StableDiffusionXLControlNetTileSRPipeline(
COSINE = "Cosine"
GAUSSIAN = "Gaussian"
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
def encode_prompt(
self,
prompt: str,
@@ -633,7 +633,7 @@ class StableDiffusionXLControlNetTileSRPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -194,7 +194,7 @@ class AnimateDiffControlNetPipeline(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
@@ -460,7 +460,7 @@ class AnimateDiffControlNetPipeline(
video = video.float()
return video
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -180,7 +180,7 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -194,7 +194,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -311,7 +311,7 @@ class AnimateDiffImgToVideoPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
@@ -577,7 +577,7 @@ class AnimateDiffImgToVideoPipeline(
video = video.float()
return video
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -165,7 +165,7 @@ class AnimateDiffPipelineIpex(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
@@ -438,7 +438,7 @@ class AnimateDiffPipelineIpex(
video = video.float()
return video
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -113,7 +113,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -458,7 +458,7 @@ class KolorsControlNetPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -133,7 +133,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -501,7 +501,7 @@ class KolorsControlNetImg2ImgPipeline(
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for others.
@@ -120,7 +120,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -134,7 +134,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -552,7 +552,7 @@ class KolorsControlNetInpaintPipeline(
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -73,7 +73,7 @@ def gaussian_filter(latents, kernel_size=3, sigma=1.0):
return blurred_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -379,7 +379,7 @@ class DemoFusionSDXLPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+1 -1
View File
@@ -186,7 +186,7 @@ class FabricPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
def _encode_prompt(
self,
prompt,
@@ -438,16 +438,21 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -1073,7 +1078,7 @@ class LocalAttention:
return out
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -1096,7 +1101,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -1120,7 +1125,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -1500,7 +1505,7 @@ class FaithDiffStableDiffusionXLPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -81,7 +81,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
@@ -95,7 +95,7 @@ def calculate_shift(
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -109,7 +109,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -502,7 +502,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
@@ -517,7 +517,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
@@ -126,7 +126,7 @@ def calculate_shift(
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -186,7 +186,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -567,7 +567,7 @@ class FluxKontextPipeline(
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
@@ -582,7 +582,7 @@ class FluxKontextPipeline(
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
@@ -89,7 +89,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
@@ -103,7 +103,7 @@ def calculate_shift(
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -86,7 +86,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
@@ -100,7 +100,7 @@ def calculate_shift(
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -640,7 +640,7 @@ class FluxSemanticGuidancePipeline(
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
@@ -655,7 +655,7 @@ class FluxSemanticGuidancePipeline(
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
+2 -2
View File
@@ -65,7 +65,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
@@ -79,7 +79,7 @@ def calculate_shift(
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -146,7 +146,7 @@ def get_resize_crop_region_for_grid(src, tgt_size):
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -161,7 +161,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor,
generator: Optional[torch.Generator] = None,
@@ -177,7 +177,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -512,7 +512,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
negative_prompt_attention_mask,
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -527,7 +527,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -66,7 +66,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -80,7 +80,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -458,7 +458,7 @@ class KolorsDifferentialImg2ImgPipeline(
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -709,7 +709,7 @@ class KolorsDifferentialImg2ImgPipeline(
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.upcast_vae
def upcast_vae(self):
dtype = self.vae.dtype
self.vae.to(dtype=torch.float32)
@@ -94,7 +94,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -243,7 +243,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
return mask, masked_image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -257,7 +257,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -669,7 +669,7 @@ class KolorsInpaintPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+5 -5
View File
@@ -57,7 +57,7 @@ from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -213,7 +213,7 @@ class Prompt2PromptPipeline(
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
def _encode_prompt(
self,
prompt,
@@ -246,7 +246,7 @@ class Prompt2PromptPipeline(
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
def encode_prompt(
self,
prompt,
@@ -430,7 +430,7 @@ class Prompt2PromptPipeline(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -445,7 +445,7 @@ class Prompt2PromptPipeline(
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -311,7 +311,7 @@ class SharedAttentionProcessor(AttnProcessor2_0):
return hidden_states
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -326,7 +326,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -371,7 +371,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -769,7 +769,7 @@ class StyleAlignedSDXLPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -66,7 +66,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -80,7 +80,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -86,7 +86,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
@@ -100,7 +100,7 @@ def calculate_shift(
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -114,7 +114,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -386,7 +386,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -297,7 +297,7 @@ class AAS_XL(AttentionBase):
return out
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -439,7 +439,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
return mask, masked_image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -453,7 +453,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -696,7 +696,7 @@ class StableDiffusionXL_AE_Pipeline(
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
def encode_prompt(
self,
prompt: str,
@@ -931,7 +931,7 @@ class StableDiffusionXL_AE_Pipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -137,7 +137,7 @@ def _preprocess_adapter_image(image, height, width):
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -237,7 +237,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
else 128
)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
def encode_prompt(
self,
prompt: str,
@@ -475,7 +475,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -283,7 +283,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
return mask, masked_image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -384,7 +384,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
else 128
)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
def encode_prompt(
self,
prompt: str,
@@ -622,7 +622,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -88,7 +88,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -103,7 +103,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -117,7 +117,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -268,7 +268,7 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
else:
self.watermark = None
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
def encode_prompt(
self,
prompt: str,
@@ -506,7 +506,7 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -98,7 +98,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -113,7 +113,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -520,7 +520,7 @@ class StableDiffusionXLPipelineIpex(
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+2 -2
View File
@@ -138,7 +138,7 @@ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -409,7 +409,7 @@ class CogVideoXSTGPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
frames = self.vae.decode(latents).sample
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -142,7 +142,7 @@ def forward_without_stg(
return hidden_states, encoder_hidden_states
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
+2 -2
View File
@@ -119,7 +119,7 @@ def forward_with_stg(
return hidden_states
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
@@ -133,7 +133,7 @@ def calculate_shift(
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -124,7 +124,7 @@ def forward_with_stg(
return hidden_states
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
@@ -138,7 +138,7 @@ def calculate_shift(
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -198,7 +198,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
+1 -1
View File
@@ -137,7 +137,7 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
return sigma_schedule
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
File diff suppressed because it is too large Load Diff
@@ -490,7 +490,7 @@ class RegionalPromptingStableDiffusionPipeline(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -841,7 +841,7 @@ class RegionalPromptingStableDiffusionPipeline(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
@@ -872,7 +872,7 @@ class RegionalPromptingStableDiffusionPipeline(
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
@@ -1062,7 +1062,7 @@ class RegionalPromptingStableDiffusionPipeline(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
@@ -1668,7 +1668,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
+2 -2
View File
@@ -234,7 +234,7 @@ class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
@@ -248,7 +248,7 @@ class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -338,7 +338,7 @@ class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
@@ -352,7 +352,7 @@ class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -358,7 +358,7 @@ class StableDiffusionReferencePipeline(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
def _encode_prompt(
self,
prompt: Union[str, List[str]],
@@ -408,7 +408,7 @@ class StableDiffusionReferencePipeline(
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
def encode_prompt(
self,
prompt: Optional[str],
@@ -639,7 +639,7 @@ class StableDiffusionReferencePipeline(
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(
self, generator: Union[torch.Generator, List[torch.Generator]], eta: float
) -> Dict[str, Any]:
@@ -789,7 +789,7 @@ class StableDiffusionReferencePipeline(
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
return ref_image_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
@@ -281,7 +281,7 @@ class StableDiffusionRepaintPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
def _encode_prompt(
self,
prompt,
@@ -427,7 +427,7 @@ class StableDiffusionRepaintPipeline(
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
@@ -438,7 +438,7 @@ class StableDiffusionRepaintPipeline(
has_nsfw_concept = None
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -456,7 +456,7 @@ class StableDiffusionRepaintPipeline(
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
@@ -832,7 +832,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
if "vae_encoder" in self.stages:
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
@@ -915,7 +915,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
@@ -788,7 +788,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
@@ -87,7 +87,7 @@ def torch_dfs(model: torch.nn.Module):
return result
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -61,7 +61,7 @@ def torch_dfs(model: torch.nn.Module):
return result
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -76,7 +76,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
+1 -2
View File
@@ -268,12 +268,11 @@ provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_f
**important**
> [!NOTE]
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source.
> To do this, execute the following steps in a new virtual environment:
> ```
> git clone https://github.com/huggingface/diffusers
> cd diffusers
> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
> pip install -e .
> ```
+315
View File
@@ -0,0 +1,315 @@
# DreamBooth training example for FLUX.2 [dev]
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux2-training)
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in:
```bash
hf auth login
```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_flux.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
## Memory Optimizations
> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption.
> However some techniques may be mutually exclusive so be sure to check before launching a training run.
### Remote Text Encoder
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
This way, the text encoder model is not loaded into memory during training.
> [!NOTE]
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
### Latent Caching
Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
### QLoRA: Low Precision Training with Quantization
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
- **FP8 training** with `torchao`:
enable FP8 training by passing `--do_fp8_training`.
> [!IMPORTANT] Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater.
> If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers like SimpleTuner, ai-toolkit, etc.
- **NF4 training** with `bitsandbytes`:
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing:
`--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
### Gradient Checkpointing and Accumulation
* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass.
by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs.
* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass.
### 8-bit-Adam Optimizer
When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
Make sure to install `bitsandbytes` if you want to do so.
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions.
### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux2"
accelerate launch train_dreambooth_lora_flux2.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--remote_text_encoder \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
to use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify -
```bash
--optimizer="prodigy"
```
> [!TIP]
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
To perform DreamBooth with LoRA, run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux2-lora"
accelerate launch train_dreambooth_lora_flux2.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--remote_text_encoder \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1. \
--report_to="wandb" \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
- lora_alpha vs. rank:
This ratio dictates the LoRA's effective strength:
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
the exact modules for LoRA training. Here are some examples of target modules you can provide:
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
> [!NOTE]
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
## Training Image-to-Image
Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image(I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
**important**
**Important**
To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
To start, you must have a dataset containing triplets:
* Condition image - the input image to be transformed.
* Target image - the desired output image after transformation.
* Instruction - a text prompt describing the transformation from the condition image to the target image.
[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
```bash
accelerate launch train_dreambooth_lora_flux2_img2img.py \
--pretrained_model_name_or_path=black-forest-labs/FLUX.2-dev \
--output_dir="flux2-i2i" \
--dataset_name="kontext-community/relighting" \
--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
--do_fp8_training \
--gradient_checkpointing \
--remote_text_encoder \
--cache_latents \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--optimizer="adamw" \
--use_8bit_adam \
--cache_latents \
--learning_rate=1e-4 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=200 \
--max_train_steps=1000 \
--rank=16\
--seed="0"
```
More generally, when performing I2I fine-tuning, we expect you to:
* Have a dataset `kontext-community/relighting`
* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
### Misc notes
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
### Aspect Ratio Bucketing
we've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
`
Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
@@ -0,0 +1,262 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAFlux2(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "dog"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2"
script_path = "examples/dreambooth/train_dreambooth_lora_flux2.py"
transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"
def test_dreambooth_lora_flux2(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
starts_with_transformer = all(
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--max_sequence_length 8
--checkpointing_steps=2
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--max_sequence_length 8
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 8
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1016,7 +1016,7 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
return new_string[:-nSpace]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
@@ -1124,7 +1124,7 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
return new_string[:-nSpace]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -1323,7 +1323,7 @@ class AnyTextPipeline(
return True
return False
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
def _encode_prompt(
self,
prompt,
@@ -1356,7 +1356,7 @@ class AnyTextPipeline(
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
def encode_prompt(
self,
prompt,
@@ -1610,7 +1610,7 @@ class AnyTextPipeline(
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -1625,7 +1625,7 @@ class AnyTextPipeline(
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
def decode_latents(self, latents):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
@@ -1637,7 +1637,7 @@ class AnyTextPipeline(
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -185,7 +185,7 @@ def get_closest_hw(width, height, image_size):
return width, height
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -457,7 +457,7 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline):
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -97,7 +97,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -285,7 +285,7 @@ class PromptDiffusionPipeline(
)
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
def _encode_prompt(
self,
prompt,
@@ -318,7 +318,7 @@ class PromptDiffusionPipeline(
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
def encode_prompt(
self,
prompt,
@@ -525,7 +525,7 @@ class PromptDiffusionPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -540,7 +540,7 @@ class PromptDiffusionPipeline(
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
def decode_latents(self, latents):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
@@ -552,7 +552,7 @@ class PromptDiffusionPipeline(
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+475
View File
@@ -0,0 +1,475 @@
import argparse
from contextlib import nullcontext
from typing import Any, Dict, Tuple
import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, GenerationConfig, Mistral3ForConditionalGeneration
from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
from diffusers.utils.import_utils import is_accelerate_available
"""
# VAE
python scripts/convert_flux2_to_diffusers.py \
--original_state_dict_repo_id "diffusers-internal-dev/new-model-image" \
--vae_filename "flux2-vae.sft" \
--output_path "/raid/yiyi/dummy-flux2-diffusers" \
--vae
# DiT
python scripts/convert_flux2_to_diffusers.py \
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
--dit_filename flux-dev-dummy.sft \
--dit \
--output_path .
# Full pipe
python scripts/convert_flux2_to_diffusers.py \
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
--dit_filename flux-dev-dummy.sft \
--vae_filename "flux2-vae.sft" \
--dit --vae --full_pipe \
--output_path .
"""
CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--dit", action="store_true")
parser.add_argument("--vae_dtype", type=str, default="fp32")
parser.add_argument("--dit_dtype", type=str, default="bf16")
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--full_pipe", action="store_true")
parser.add_argument("--output_path", type=str)
args = parser.parse_args()
def load_original_checkpoint(args, filename):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
DIFFUSERS_VAE_TO_FLUX2_MAPPING = {
"encoder.conv_in.weight": "encoder.conv_in.weight",
"encoder.conv_in.bias": "encoder.conv_in.bias",
"encoder.conv_out.weight": "encoder.conv_out.weight",
"encoder.conv_out.bias": "encoder.conv_out.bias",
"encoder.conv_norm_out.weight": "encoder.norm_out.weight",
"encoder.conv_norm_out.bias": "encoder.norm_out.bias",
"decoder.conv_in.weight": "decoder.conv_in.weight",
"decoder.conv_in.bias": "decoder.conv_in.bias",
"decoder.conv_out.weight": "decoder.conv_out.weight",
"decoder.conv_out.bias": "decoder.conv_out.bias",
"decoder.conv_norm_out.weight": "decoder.norm_out.weight",
"decoder.conv_norm_out.bias": "decoder.norm_out.bias",
"quant_conv.weight": "encoder.quant_conv.weight",
"quant_conv.bias": "encoder.quant_conv.bias",
"post_quant_conv.weight": "decoder.post_quant_conv.weight",
"post_quant_conv.bias": "decoder.post_quant_conv.bias",
"bn.running_mean": "bn.running_mean",
"bn.running_var": "bn.running_var",
}
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = (
ldm_key.replace(mapping["old"], mapping["new"])
.replace("norm.weight", "group_norm.weight")
.replace("norm.bias", "group_norm.bias")
.replace("q.weight", "to_q.weight")
.replace("q.bias", "to_q.bias")
.replace("k.weight", "to_k.weight")
.replace("k.bias", "to_k.bias")
.replace("v.weight", "to_v.weight")
.replace("v.bias", "to_v.bias")
.replace("proj_out.weight", "to_out.0.weight")
.replace("proj_out.bias", "to_out.0.bias")
)
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
# proj_attn.weight has to be converted from conv 1D to linear
shape = new_checkpoint[diffusers_key].shape
if len(shape) == 3:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
elif len(shape) == 4:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
def convert_flux2_vae_checkpoint_to_diffusers(vae_state_dict, config):
new_checkpoint = {}
for diffusers_key, ldm_key in DIFFUSERS_VAE_TO_FLUX2_MAPPING.items():
if ldm_key not in vae_state_dict:
continue
new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len(config["down_block_types"])
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
)
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
f"encoder.down.{i}.downsample.conv.bias"
)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
update_vae_attentions_ldm_to_diffusers(
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
)
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len(config["up_block_types"])
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
)
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
update_vae_attentions_ldm_to_diffusers(
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
# Image and text input projections
"img_in": "x_embedder",
"txt_in": "context_embedder",
# Timestep and guidance embeddings
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
# Modulation parameters
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
"single_stream_modulation.lin": "single_stream_modulation.linear",
# Final output layer
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
"final_layer.linear": "proj_out",
}
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
"final_layer.adaLN_modulation.1": "norm_out.linear",
}
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
# Handle fused QKV projections separately as we need to break into Q, K, V projections
"img_attn.norm.query_norm": "attn.norm_q",
"img_attn.norm.key_norm": "attn.norm_k",
"img_attn.proj": "attn.to_out.0",
"img_mlp.0": "ff.linear_in",
"img_mlp.2": "ff.linear_out",
"txt_attn.norm.query_norm": "attn.norm_added_q",
"txt_attn.norm.key_norm": "attn.norm_added_k",
"txt_attn.proj": "attn.to_add_out",
"txt_mlp.0": "ff_context.linear_in",
"txt_mlp.2": "ff_context.linear_out",
}
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
"linear1": "attn.to_qkv_mlp_proj",
"linear2": "attn.to_out",
"norm.query_norm": "attn.norm_q",
"norm.key_norm": "attn.norm_k",
}
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use
# diffusers implementation
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight
if ".weight" not in key:
return
# If adaLN_modulation is in the key, swap scale and shift parameters
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
if "adaLN_modulation" in key:
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
# Assume all such keys are in the AdaLayerNorm key map
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
new_key = ".".join([new_key_without_param_type, param_type])
swapped_weight = swap_scale_shift(state_dict.pop(key))
state_dict[new_key] = swapped_weight
return
def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight, bias, or scale
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
return
new_prefix = "transformer_blocks"
if "double_blocks." in key:
parts = key.split(".")
block_idx = parts[1]
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
within_block_name = ".".join(parts[2:-1])
param_type = parts[-1]
if param_type == "scale":
param_type = "weight"
if "qkv" in within_block_name:
fused_qkv_weight = state_dict.pop(key)
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
if "img" in modality_block_name:
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = "attn.to_q"
new_k_name = "attn.to_k"
new_v_name = "attn.to_v"
elif "txt" in modality_block_name:
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = "attn.add_q_proj"
new_k_name = "attn.add_k_proj"
new_v_name = "attn.add_v_proj"
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
state_dict[new_q_key] = to_q_weight
state_dict[new_k_key] = to_k_weight
state_dict[new_v_key] = to_v_weight
else:
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight, bias, or scale
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
return
# Mapping:
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
new_prefix = "single_transformer_blocks"
if "single_blocks." in key:
parts = key.split(".")
block_idx = parts[1]
within_block_name = ".".join(parts[2:-1])
param_type = parts[-1]
if param_type == "scale":
param_type = "weight"
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
param = state_dict.pop(key)
state_dict[new_key] = param
return
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"adaLN_modulation": convert_ada_layer_norm_weights,
"double_blocks": convert_flux2_double_stream_blocks,
"single_blocks": convert_flux2_single_stream_blocks,
}
def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
if model_type == "test" or model_type == "dummy-flux2":
config = {
"model_id": "diffusers-internal-dev/dummy-flux2",
"diffusers_config": {
"patch_size": 1,
"in_channels": 128,
"num_layers": 8,
"num_single_layers": 48,
"attention_head_dim": 128,
"num_attention_heads": 48,
"joint_attention_dim": 15360,
"timestep_guidance_channels": 256,
"mlp_ratio": 3.0,
"axes_dims_rope": (32, 32, 32, 32),
"rope_theta": 2000,
"eps": 1e-6,
},
}
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str):
config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
transformer = Flux2Transformer2DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def main(args):
if args.vae:
original_vae_ckpt = load_original_checkpoint(args, filename=args.vae_filename)
vae = AutoencoderKLFlux2()
converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_vae_ckpt, vae.config)
vae.load_state_dict(converted_vae_state_dict, strict=True)
if not args.full_pipe:
vae_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
vae.to(vae_dtype).save_pretrained(f"{args.output_path}/vae")
if args.dit:
original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
if not args.full_pipe:
dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
if args.full_pipe:
tokenizer_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
text_encoder_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
generate_config = GenerationConfig.from_pretrained(text_encoder_id)
generate_config.do_sample = True
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
text_encoder_id, generation_config=generate_config, torch_dtype=torch.bfloat16
)
tokenizer = AutoProcessor.from_pretrained(tokenizer_id)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
"black-forest-labs/FLUX.1-dev", subfolder="scheduler"
)
pipe = Flux2Pipeline(
vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)
pipe.save_pretrained(args.output_path)
if __name__ == "__main__":
main(args)
@@ -10,7 +10,7 @@ from accelerate import init_empty_weights
from diffusers import (
SanaControlNetModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
+1 -1
View File
@@ -20,7 +20,7 @@ from diffusers import (
SanaTransformer2DModel,
SCMScheduler,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
@@ -80,6 +80,8 @@ def main(args):
# scheduler
flow_shift = 8.0
if args.task == "i2v":
assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
# model config
layer_num = 20
@@ -312,6 +314,7 @@ if __name__ == "__main__":
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
+1 -1
View File
@@ -7,7 +7,7 @@ from accelerate import init_empty_weights
from diffusers import AutoencoderKL, SD3Transformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
+1 -1
View File
@@ -18,7 +18,7 @@ from diffusers import (
StableAudioPipeline,
StableAudioProjectionModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils import is_accelerate_available
+1 -1
View File
@@ -20,7 +20,7 @@ from diffusers import (
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
+1 -1
View File
@@ -20,7 +20,7 @@ from diffusers import (
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
+265 -6
View File
@@ -6,11 +6,20 @@ import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
from transformers import (
AutoProcessor,
AutoTokenizer,
CLIPImageProcessor,
CLIPVisionModel,
CLIPVisionModelWithProjection,
UMT5EncoderModel,
)
from diffusers import (
AutoencoderKLWan,
UniPCMultistepScheduler,
WanAnimatePipeline,
WanAnimateTransformer3DModel,
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
@@ -105,8 +114,203 @@ VACE_TRANSFORMER_KEYS_RENAME_DICT = {
"after_proj": "proj_out",
}
ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"cross_attn.k_img": "attn2.to_k_img",
"cross_attn.v_img": "attn2.to_v_img",
"cross_attn.norm_k_img": "attn2.norm_k_img",
# After cross_attn -> attn2 rename, we need to rename the img keys
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
# Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
# Motion encoder mappings
# The name mapping is complicated for the convolutional part so we handle that in its own function
"motion_encoder.enc.fc": "motion_encoder.motion_network",
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
# Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
"face_encoder.conv1_local.conv": "face_encoder.conv1_local",
"face_encoder.conv2.conv": "face_encoder.conv2",
"face_encoder.conv3.conv": "face_encoder.conv3",
# Face adapter mappings are handled in a separate function
}
# TODO: Verify this and simplify if possible.
def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
"""
Convert all motion encoder weights for Animate model.
In the original model:
- All Linear layers in fc use EqualLinear
- All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
- Blur kernels are stored as buffers in Sequential modules
- ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]
Conversion strategy:
1. Drop .kernel buffers (blur kernels)
2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
"""
# Skip if not a weight, bias, or kernel
if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
return
# Handle Blur kernel buffers from original implementation.
# After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
# Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
if ".kernel" in key and "motion_encoder" in key:
# Remove unexpected blur kernel buffers to avoid strict load errors
state_dict.pop(key, None)
return
# Rename Sequential indices to named components in ConvLayer and ResBlock
if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
parts = key.split(".")
# Find the sequential index (digit) after convs or after conv1/conv2/skip
# Examples:
# - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
# - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
# - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
# - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
# - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
# - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
# - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
# - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
# - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
# - enc.net_app.convs.8 -> conv_out (final conv layer)
convs_idx = parts.index("convs") if "convs" in parts else -1
if convs_idx >= 0 and len(parts) - convs_idx >= 2:
bias = False
# The nn.Sequential index will always follow convs
sequential_idx = int(parts[convs_idx + 1])
if sequential_idx == 0:
if key.endswith(".weight"):
new_key = "motion_encoder.conv_in.weight"
elif key.endswith(".bias"):
new_key = "motion_encoder.conv_in.act_fn.bias"
bias = True
elif sequential_idx == final_conv_idx:
if key.endswith(".weight"):
new_key = "motion_encoder.conv_out.weight"
else:
# Intermediate .convs. layers, which get mapped to .res_blocks.
prefix = "motion_encoder.res_blocks."
layer_name = parts[convs_idx + 2]
if layer_name == "skip":
layer_name = "conv_skip"
if key.endswith(".weight"):
param_name = "weight"
elif key.endswith(".bias"):
param_name = "act_fn.bias"
bias = True
suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
suffix = ".".join(suffix_parts)
new_key = prefix + suffix
param = state_dict.pop(key)
if bias:
param = param.squeeze()
state_dict[new_key] = param
return
return
return
def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
"""
Convert face adapter weights for the Animate model.
The original model uses a fused KV projection but the diffusers models uses separate K and V projections.
"""
# Skip if not a weight or bias
if ".weight" not in key and ".bias" not in key:
return
prefix = "face_adapter."
if ".fuser_blocks." in key:
parts = key.split(".")
module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
block_idx = parts[module_list_idx + 1]
layer_name = parts[module_list_idx + 2]
param_name = parts[module_list_idx + 3]
if layer_name == "linear1_kv":
layer_name_k = "to_k"
layer_name_v = "to_v"
suffix_k = ".".join([block_idx, layer_name_k, param_name])
suffix_v = ".".join([block_idx, layer_name_v, param_name])
new_key_k = prefix + suffix_k
new_key_v = prefix + suffix_v
kv_proj = state_dict.pop(key)
k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
state_dict[new_key_k] = k_proj
state_dict[new_key_v] = v_proj
return
else:
if layer_name == "q_norm":
new_layer_name = "norm_q"
elif layer_name == "k_norm":
new_layer_name = "norm_k"
elif layer_name == "linear1_q":
new_layer_name = "to_q"
elif layer_name == "linear2":
new_layer_name = "to_out"
suffix_parts = [block_idx, new_layer_name, param_name]
suffix = ".".join(suffix_parts)
new_key = prefix + suffix
state_dict[new_key] = state_dict.pop(key)
return
return
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"motion_encoder": convert_animate_motion_encoder_weights,
"face_adapter": convert_animate_face_adapter_weights,
}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -364,6 +568,37 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-Animate-14B":
config = {
"model_id": "Wan-AI/Wan2.2-Animate-14B",
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": (1, 2, 2),
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"rope_max_seq_len": 1024,
"pos_embed_seq_len": None,
"motion_encoder_size": 512, # Start of Wan Animate-specific configs
"motion_style_dim": 512,
"motion_dim": 20,
"motion_encoder_dim": 512,
"face_encoder_hidden_dim": 1024,
"face_encoder_num_heads": 4,
"inject_face_latents_blocks": 5,
},
}
RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
@@ -380,10 +615,12 @@ def convert_transformer(model_type: str, stage: str = None):
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
if "VACE" not in model_type:
transformer = WanTransformer3DModel.from_config(diffusers_config)
else:
if "Animate" in model_type:
transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
elif "VACE" in model_type:
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
else:
transformer = WanTransformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -397,7 +634,12 @@ def convert_transformer(model_type: str, stage: str = None):
continue
handler_fn_inplace(key, original_state_dict)
# Load state dict into the meta model, which will materialize the tensors
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
# Move to CPU to ensure all tensors are materialized
transformer = transformer.to("cpu")
return transformer
@@ -926,7 +1168,7 @@ DTYPE_MAPPING = {
if __name__ == "__main__":
args = get_args()
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
transformer = convert_transformer(args.model_type, stage="high_noise_model")
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
else:
@@ -942,7 +1184,7 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
if "FLF2V" in args.model_type:
flow_shift = 16.0
elif "TI2V" in args.model_type:
elif "TI2V" in args.model_type or "Animate" in args.model_type:
flow_shift = 5.0
else:
flow_shift = 3.0
@@ -954,6 +1196,8 @@ if __name__ == "__main__":
if args.dtype != "none":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)
if transformer_2 is not None:
transformer_2.to(dtype)
if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
pipe = WanImageToVideoPipeline(
@@ -1016,6 +1260,21 @@ if __name__ == "__main__":
vae=vae,
scheduler=scheduler,
)
elif "Animate" in args.model_type:
image_encoder = CLIPVisionModel.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
pipe = WanAnimatePipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
image_encoder=image_encoder,
image_processor=image_processor,
)
else:
pipe = WanPipeline(
transformer=transformer,
+18
View File
@@ -186,6 +186,7 @@ else:
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLCosmos",
"AutoencoderKLFlux2",
"AutoencoderKLHunyuanImage",
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
@@ -215,6 +216,7 @@ else:
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
"Flux2Transformer2DModel",
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
@@ -268,8 +270,10 @@ else:
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
"WanAnimateTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"ZImageTransformer2DModel",
"attention_backend",
]
)
@@ -407,6 +411,7 @@ else:
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanModularPipeline",
]
@@ -455,6 +460,7 @@ else:
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
"Flux2Pipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline",
@@ -543,11 +549,13 @@ else:
"QwenImagePipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
"SanaImageToVideoPipeline",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintImg2ImgPipeline",
"SanaSprintPipeline",
"SanaVideoPipeline",
"SanaVideoPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -635,6 +643,7 @@ else:
"VisualClozeGenerationPipeline",
"VisualClozePipeline",
"VQDiffusionPipeline",
"WanAnimatePipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVACEPipeline",
@@ -642,6 +651,7 @@ else:
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
"ZImagePipeline",
]
)
@@ -895,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLFlux2,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
@@ -924,6 +935,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CosmosTransformer3DModel,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
@@ -976,6 +988,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
attention_backend,
@@ -1090,6 +1103,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
Wan22AutoBlocks,
WanAutoBlocks,
WanModularPipeline,
)
@@ -1134,6 +1148,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
Flux2Pipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
@@ -1222,6 +1237,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImagePipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
SanaImageToVideoPipeline,
SanaPAGPipeline,
SanaPipeline,
SanaSprintImg2ImgPipeline,
@@ -1313,6 +1329,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VisualClozeGenerationPipeline,
VisualClozePipeline,
VQDiffusionPipeline,
WanAnimatePipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVACEPipeline,
@@ -1320,6 +1337,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
ZImagePipeline,
)
try:
@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -88,6 +88,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -99,6 +99,19 @@ class AdaptiveProjectedMixGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

Some files were not shown because too many files have changed in this diff Show More