Compare commits

...

83 Commits

Author SHA1 Message Date
DN6 259b0c86dd update 2025-11-19 08:55:48 +05:30
Sayak Paul ab71f3c864 [core] Refactor hub attn kernels (#12475)
* refactor how attention kernels from hub are used.

* up

* refactor according to Dhruv's ideas.

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: dn6 <dhruv@huggingface.co>

* up

---------

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-11-19 08:19:00 +05:30
Dhruv Nair b7df4a5387 [CI] Temporarily pin transformers (#12677)
* update

* update

* update

* update
2025-11-18 14:43:06 +05:30
dg845 67dc65e2e3 Revert AutoencoderKLWan's dim_mult default value back to list (#12640)
Revert dim_mult back to list and fix type annotation
2025-11-17 18:39:53 +05:30
Dhruv Nair 3579fdabf9 [CI] Make CI logs less verbose (#12674)
update
2025-11-17 14:23:09 +05:30
Junsong Chen 1afc21855e SANA-Video Image to Video pipeline SanaImageToVideoPipeline support (#12634)
* move sana-video to a new dir and add `SanaImageToVideoPipeline` with no modify;

* fix bug and run text/image-to-vidoe success;

* make style; quality; fix-copies;

* add sana image-to-video pipeline in markdown;

* add test case for sana image-to-video;

* make style;

* add a init file in sana-video test dir;

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana_video/test_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana_video/test_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* minor update;

* fix bug and skip fp16 save test;

Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* add copied from for `encode_prompt`

* Apply style fixes

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-17 00:23:34 -08:00
David Bertoin 0c35b580fe [PRX pipeline]: add 1024 resolution ratio bins (#12670)
add 1024 ratio bins
2025-11-17 10:37:40 +05:30
David Bertoin 01a56927f1 Rope in float32 for mps or npu compatibility (#12665)
rope in float32
2025-11-15 20:44:34 +05:30
dg845 a9e4883b6a Update Wan Animate Docs (#12658)
* Update the Wan Animate docs to reflect the most recent code

* Further explain input preprocessing and link to original Wan Animate preprocessing scripts
2025-11-14 16:06:22 -08:00
David El Malih 63dd601758 Improve docstrings and type hints in scheduling_euler_discrete.py (#12654)
* refactor: enhance type hints and documentation in EulerDiscreteScheduler

Updated type hints for function parameters and return types in the EulerDiscreteScheduler class to improve code clarity and maintainability. Enhanced docstrings for several methods to provide clearer descriptions of their functionality and expected arguments. This includes specifying Literal types for certain parameters and ensuring consistent return type annotations across the class.

* refactor: enhance type hints and documentation across multiple schedulers

Updated type hints and improved docstrings in various scheduler classes, including CMStochasticIterativeScheduler, CosineDPMSolverMultistepScheduler, and others. This includes specifying parameter types, return types, and providing clearer descriptions of method functionalities. Notable changes include the addition of default values in the begin_index argument and enhanced explanations for noise addition methods. These improvements aim to enhance code clarity and maintainability across the scheduling module.

* refactor: update docstrings to clarify noise schedule construction

Revised docstrings across multiple scheduler classes to enhance clarity regarding the construction of noise schedules. Updated references to relevant papers, ensuring accurate citations for the methodologies used. This includes changes in DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, and others, improving documentation consistency and readability.
2025-11-14 15:12:24 -08:00
Dhruv Nair eeae0338e7 [Modular] Add Custom Blocks guide to doc (#12339)
* update

* update

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/_toctree.yml

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* update

* update

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* update

* update

* update

* update

* Update docs/source/en/modular_diffusers/custom_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-11-14 10:59:59 +05:30
David El Malih 3c1ca869d7 Improve docstrings and type hints in scheduling_ddpm.py (#12651)
* Enhance type hints and docstrings in scheduling_ddpm.py

- Added type hints for function parameters and return types across the DDPMScheduler class and related functions.
- Improved docstrings for clarity, including detailed descriptions of parameters and return values.
- Updated the alpha_transform_type and beta_schedule parameters to use Literal types for better type safety.
- Refined the _get_variance and previous_timestep methods with comprehensive documentation.

* Refactor docstrings and type hints in scheduling_ddpm.py

- Cleaned up whitespace in the rescale_zero_terminal_snr function.
- Enhanced the variance_type parameter in the DDPMScheduler class with improved formatting for better readability.
- Updated the docstring for the compute_variance method to maintain consistency and clarity in parameter descriptions and return values.

* Apply `make fix-copies`

* Refactor type hints across multiple scheduler files

- Updated type hints to include `Literal` for improved type safety in various scheduling files.
- Ensured consistency in type hinting for parameters and return types across the affected modules.
- This change enhances code clarity and maintainability.
2025-11-13 14:46:23 -08:00
David El Malih 6fe4a6ff8e Improve docstrings and type hints in scheduling_ddim.py (#12622)
* Improve docstrings and type hints in scheduling_ddim.py

- Add complete type hints for all function parameters
- Enhance docstrings to follow project conventions
- Add missing parameter descriptions

Fixes #9567

* Enhance docstrings and type hints in scheduling_ddim.py

- Update parameter types and descriptions for clarity
- Improve explanations in method docstrings to align with project standards
- Add optional annotations for parameters where applicable

* Refine type hints and docstrings in scheduling_ddim.py

- Update parameter types to use Literal for specific string options
- Enhance docstring descriptions for clarity and consistency
- Ensure all parameters have appropriate type annotations and defaults

* Apply review feedback on scheduling_ddim.py

- Replace "prevent singularities" with "avoid numerical instability" for better clarity
- Add backticks around `alpha_bar` variable name for consistent formatting
- Convert Imagen Video paper URLs to Hugging Face papers references

* Propagate changes using 'make fix-copies'

* Add missing Literal
2025-11-13 14:45:58 -08:00
Steven Liu 40de88af8c [docs] AutoModel (#12644)
* automodel

* fix
2025-11-13 08:43:24 -08:00
Steven Liu 6a2309b98d [utils] Update check_doc_toc (#12642)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-13 08:42:31 -08:00
Sayak Paul cd3bbe2910 skip autoencoderdl layerwise casting memory (#12647) 2025-11-13 12:56:22 +05:30
kaixuanliu 7a001c3ee2 adjust unit tests for test_save_load_float16 (#12500)
* adjust unit tests for wan pipeline

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* avoid adjusting common `get_dummy_components` API

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* use `form_pretrained` to `transformer` and `transformer_2`

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-11-13 11:57:12 +05:30
dg845 d8e4805816 [WIP]Add Wan2.2 Animate Pipeline (Continuation of #12442 by tolgacangoz) (#12526)
---------

Co-authored-by: Tolga Cangöz <mtcangoz@gmail.com>
Co-authored-by: Tolga Cangöz <46008593+tolgacangoz@users.noreply.github.com>
2025-11-12 16:52:31 -10:00
David El Malih 44c3101685 Improve docstrings and type hints in scheduling_amused.py (#12623)
* Improve docstrings and type hints in scheduling_amused.py

- Add complete type hints for helper functions (gumbel_noise, mask_by_random_topk)
- Enhance AmusedSchedulerOutput with proper Optional typing
- Add comprehensive docstrings for AmusedScheduler class
- Improve __init__, set_timesteps, step, and add_noise methods
- Fix type hints to match documentation conventions
- All changes follow project standards from issue #9567

* Enhance type hints and docstrings in scheduling_amused.py

- Update type hints for `prev_sample` and `pred_original_sample` in `AmusedSchedulerOutput` to reflect their tensor types.
- Improve docstring for `gumbel_noise` to specify the output tensor's dtype and device.
- Refine `AmusedScheduler` class documentation, including detailed descriptions of the masking schedule and temperature parameters.
- Adjust type hints in `set_timesteps` and `step` methods for better clarity and consistency.

* Apply review feedback on scheduling_amused.py

- Replace generic [Amused] reference with specific [`AmusedPipeline`] reference for consistency with project documentation conventions
2025-11-12 17:26:10 -08:00
YiYi Xu d6c63bb956 [modular] add a check (#12628)
* add

* fix
2025-11-12 07:59:18 -10:00
Steven Liu 2f44d63046 [docs] Update install instructions (#12626)
remove commit

Removed specific commit reference for installation instructions.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-11-12 09:21:24 -08:00
Quentin Gallouédec f3db38c1e7 ArXiv -> HF Papers (#12583)
* Update pipeline_skyreels_v2_i2v.py

* Update README.md

* Update torch_utils.py

* Update torch_utils.py

* Update guider_utils.py

* Update pipeline_ltx.py

* Update pipeline_bria.py

* Apply suggestion from @qgallouedec

* Update autoencoder_kl_qwenimage.py

* Update pipeline_prx.py

* Update pipeline_wan_vace.py

* Update pipeline_skyreels_v2.py

* Update pipeline_skyreels_v2_diffusion_forcing.py

* Update pipeline_bria_fibo.py

* Update pipeline_skyreels_v2_diffusion_forcing_i2v.py

* Update pipeline_ltx_condition.py

* Update pipeline_ltx_image2video.py

* Update regional_prompting_stable_diffusion.py

* make style

* style

* style
2025-11-12 08:37:21 -08:00
Sayak Paul f5e5f34823 [modular] add tests for qwen modular (#12585)
* add tests for qwenimage modular.

* qwenimage edit.

* qwenimage edit plus.

* empty

* align with the latest structure

* up

* up

* reason

* up

* fix multiple issues.

* up

* up

* fix

* up

* make it similar to the original pipeline.
2025-11-12 17:37:42 +05:30
YiYi Xu 093cd3f040 fix dispatch_attention_fn check (#12636)
* fix

* fix
2025-11-11 19:16:13 -10:00
a120092009 aecf0c53bf Add MLU Support. (#12629)
* Add MLU Support.

* fix comment.

* rename is_mlu_available to is_torch_mlu_available

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-11 19:15:26 -10:00
YiYi Xu 0c7589293b fix copies (#12637)
* fix

* remoce cocpies instead
2025-11-11 15:44:55 -10:00
Charchit Sharma ff263947ad Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers (#12594)
* Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers

- Store t_dim, h_dim, w_dim as instance variables in WanRotaryPosEmbed and SkyReelsV2RotaryPosEmbed __init__
- Use stored dimensions in forward() instead of recalculating with different formula
- Fixes inconsistency between init (using // 6) and forward (using // 3)
- Ensures split_sizes matches the dimensions used to create rotary embeddings

* quality fix

---------

Co-authored-by: Charchit Sharma <charchitsharma@A-267.local>
2025-11-11 11:45:36 -10:00
Dhruv Nair 66e6a0215f [CI] Remove unittest dependency from testing_utils.py (#12621)
* update

* Update tests/testing_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/testing_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-11 16:40:39 +05:30
Cesaryuan 5a47442f92 Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples (#12544)
* Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-10 13:57:52 -08:00
Dhruv Nair 8f6328c4a4 [Modular] Clean up docs (#12604)
update

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-11-10 23:37:29 +05:30
Dhruv Nair 8d45f219d0 Fix Context Parallel validation checks (#12446)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-10 23:37:07 +05:30
Yashwant Bezawada 0fd58c7706 fix: correct import path for load_model_dict_into_meta in conversion scripts (#12616)
The function load_model_dict_into_meta was moved from modeling_utils.py to
model_loading_utils.py but the imports in the conversion scripts were not
updated, causing ImportError when running these scripts.

This fixes the import in 6 conversion scripts:
- scripts/convert_sd3_to_diffusers.py
- scripts/convert_stable_cascade_lite.py
- scripts/convert_stable_cascade.py
- scripts/convert_stable_audio.py
- scripts/convert_sana_to_diffusers.py
- scripts/convert_sana_controlnet_to_diffusers.py

Fixes #12606
2025-11-10 14:47:18 +05:30
Dhruv Nair 35d703310c [CI] Fix typo in uv install (#12618)
update
2025-11-10 13:22:46 +05:30
YiYi Xu b455dc94a2 [modular] wan! (#12611)
* update, remove intermediaate_inputs

* support image2video

* revert dynamic steps to simplify

* refactor vae encoder block

* support flf2video!

* add support for wan2.2 14B

* style

* Apply suggestions from code review

* input dynamic step -> additiional input step

* up

* fix init

* update dtype
2025-11-09 21:48:50 -10:00
Jay Wu 04f9d2bf3d add ChronoEdit (#12593)
* add ChronoEdit

* add ref to  original function & remove wan2.2 logics

* Update src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* add ChronoeEdit test

* add docs

* add docs

* make fix-copies

* fix chronoedit test

---------

Co-authored-by: wjay <wjay@nvidia.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-09 22:07:00 -08:00
Dhruv Nair bc8fd864eb [CI] Push test fix (#12617)
update
2025-11-10 09:26:14 +05:30
Wang, Yi a9cb08af39 fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled (#12562)
* fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* address review comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-11-07 20:00:13 +05:30
DefTruth 9f669e7b5d feat: enable attention dispatch for huanyuan video (#12591)
* feat: enable attention dispatch for huanyuan video

* feat: enable attention dispatch for huanyuan video
2025-11-07 11:22:41 +05:30
Dhruv Nair 8ac17cd2cb [Modular] Some clean up for Modular tests (#12579)
* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-07 08:19:15 +05:30
Mohammad Sadegh Salehi e4393fa613 Fix overflow and dtype handling in rgblike_to_depthmap (NumPy + PyTorch) (#12546)
* Fix overflow in rgblike_to_depthmap by safe dtype casting (torch & NumPy)

* Fix: store original dtype and cast back after safe computation

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-06 08:18:21 -10:00
Junsong Chen b3e9dfced7 [SANA-Video] Adding 5s pre-trained 480p SANA-Video inference (#12584)
* 1. add `SanaVideoTransformer3DModel` in transformer_sana_video.py
2. add `SanaVideoPipeline` in pipeline_sana_video.py
3. add all code we need for import `SanaVideoPipeline`

* add a sample about how to use sana-video;

* code update;

* update hf model path;

* update code;

* sana-video can run now;

* 1. add aspect ratio in sana-video-pipeline;
2. add reshape function in sana-video-processor;
3. fix convert pth to safetensor bugs;

* default to use `use_resolution_binning`;

* make style;

* remove unused code;

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana/pipeline_sana_video.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

* Update src/diffusers/pipelines/sana/pipeline_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana/pipeline_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* support `dispatch_attention_fn`

* 1. add sana-video markdown;
2. fix typos;

* add two test case for sana-video (need check)

* fix text-encoder in test-sana-video;

* Update tests/pipelines/sana/test_sana_video.py

* Update tests/pipelines/sana/test_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana/test_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana/test_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana/test_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana/test_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana/pipeline_sana_video.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/video_processor.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* make style
make quality
make fix-copies

* toctree yaml update;

* add sana-video-transformer3d markdown;

* Apply style fixes

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-05 21:08:47 -08:00
Joseph Turian 58f3771545 Add optional precision-preserving preprocessing for examples/unconditional_image_generation/train_unconditional.py (#12596)
* Add optional precision-preserving preprocessing

* Document decoder caveat for precision flag

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-06 09:37:31 +05:30
Dhruv Nair 6198f8a12b [Modular] Allow ModularPipeline to load from revisions (#12592)
* update

* update

* update

* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-11-06 07:54:24 +05:30
Linoy Tsaban dcfb18a2d3 [LoRA] add support for more Qwen LoRAs (#12581)
* fix bug when offload and cache_latents both enabled

* fix
2025-11-04 14:27:25 +02:00
Sayak Paul ac5a1e28fc [docs] sort doc (#12586)
sort doc
2025-11-04 10:26:07 +05:30
Lev Novitskiy 325a95051b Kandinsky 5.0 Docs fixes (#12582)
* add transformer pipeline first version

* updates

* fix 5sec generation

* rewrite Kandinsky5T2VPipeline to diffusers style

* add multiprompt support

* remove prints in pipeline

* add nabla attention

* Wrap Transformer in Diffusers style

* fix license

* fix prompt type

* add gradient checkpointing and peft support

* add usage example

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* remove unused imports

* add 10 second models support

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove no_grad and simplified prompt paddings

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* moved template to __init__

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* moved sdps inside processor

* remove oneline function

* remove reset_dtype methods

* Transformer: move all methods to forward

* separated prompt encoding

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* refactoring

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* refactoring acording to https://github.com/huggingface/diffusers/commit/acabbc0033d4b4933fc651766a4aa026db2e6dc1

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* fixed

* style +copies

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: Charles <charles@huggingface.co>

* more

* Apply suggestions from code review

* add lora loader doc

* add compiled Nabla Attention

* all needed changes for 10 sec models are added!

* add docs

* Apply style fixes

* update docs

* add kandinsky5 to toctree

* add tests

* fix tests

* Apply style fixes

* update tests

* minor docs refactoring

* refactor Kandinsky 5.0 Vide docs

* Update docs/source/en/_toctree.yml

---------

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Charles <charles@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-03 14:38:07 -10:00
Wang, Yi 1ec28a2c77 ulysses enabling in native attention path (#12563)
* ulysses enabling in native attention path

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* address review comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add supports_context_parallel for native attention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update templated attention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-03 11:48:20 -10:00
YiYi Xu de6173c683 [modular]pass hub_kwargs to load_config (#12577)
pass hub_kwargs to load_config
2025-11-03 09:44:42 -10:00
Sayak Paul 8f80dda193 [tests] add tests for flux modular (t2i, i2i, kontext) (#12566)
* start flux modular tests.

* up

* add kontext

* up

* up

* up

* Update src/diffusers/modular_pipelines/flux/denoise.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* up

* up

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-11-02 10:51:11 +05:30
YiYi Xu cdbf0ad883 [modular] better warn message (#12573)
better warn message
2025-11-01 18:45:09 -10:00
Dhruv Nair 5e8415a311 Fix custom code loading in Automodel (#12571)
update
2025-11-01 17:04:31 -10:00
Friedrich Schöller 051c8a1c0f Fix Stable Diffusion 3.x pooled prompt embedding with multiple images (#12306) 2025-10-31 10:25:13 -10:00
Dhruv Nair d54622c267 [Modular] Allow custom blocks to be saved to local_dir (#12381)
update

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-10-31 13:47:02 +05:30
Dhruv Nair df8dd77817 [Modular] Fix for custom block kwargs (#12561)
update
2025-10-31 00:14:24 +05:30
Pavle Padjin 9f3c0fdcd8 Avoiding graph break by changing the way we infer dtype in vae.decoder (#12512)
* Changing the way we infer dtype to avoid force evaluation of lazy tensors

* changing way to infer dtype to ensure type consistency

* more robust infering of dtype

* removing the upscale dtype entirely
2025-10-30 08:39:40 +05:30
galbria 84e16575e4 Bria fibo (#12545)
* Bria FIBO pipeline

* style fixs

* fix CR

* Refactor BriaFibo classes and update pipeline parameters

- Updated BriaFiboAttnProcessor and BriaFiboAttention classes to reflect changes from Flux equivalents.
- Modified the _unpack_latents method in BriaFiboPipeline to improve clarity.
- Increased the default max_sequence_length to 3000 and added a new optional parameter do_patching.
- Cleaned up test_pipeline_bria_fibo.py by removing unused imports and skipping unsupported tests.

* edit the docs of FIBO

* Remove unused BriaFibo imports and update CPU offload method in BriaFiboPipeline

* Refactor FIBO classes to BriaFibo naming convention

- Updated class names from FIBO to BriaFibo for consistency across the module.
- Modified instances of FIBOEmbedND, FIBOTimesteps, TextProjection, and TimestepProjEmbeddings to reflect the new naming.
- Ensured all references in the BriaFiboTransformer2DModel are updated accordingly.

* Add BriaFiboTransformer2DModel import to transformers module

* Remove unused BriaFibo imports from modular pipelines and add BriaFiboTransformer2DModel and BriaFiboPipeline classes to dummy objects for enhanced compatibility with torch and transformers.

* Update BriaFibo classes with copied documentation and fix import typo in pipeline module

- Added documentation comments indicating the source of copied code in BriaFiboTransformerBlock and _pack_latents methods.
- Corrected the import statement for BriaFiboPipeline in the pipelines module.

* Remove unused BriaFibo imports from __init__.py to streamline modular pipelines.

* Refactor documentation comments in BriaFibo classes to indicate inspiration from existing implementations

- Updated comments in BriaFiboAttnProcessor, BriaFiboAttention, and BriaFiboPipeline to reflect that the code is inspired by other modules rather than copied.
- Enhanced clarity on the origins of the methods to maintain proper attribution.

* change Inspired by to Based on

* add reference link and fix trailing whitespace

* Add BriaFiboTransformer2DModel documentation and update comments in BriaFibo classes

- Introduced a new documentation file for BriaFiboTransformer2DModel.
- Updated comments in BriaFiboAttnProcessor, BriaFiboAttention, and BriaFiboPipeline to clarify the origins of the code, indicating copied sources for better attribution.

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
2025-10-28 16:27:48 +05:30
Sayak Paul 55d49d4379 [ci] don't run sana layerwise casting tests in CI. (#12551)
* don't run sana layerwise casting tests in CI.

* up
2025-10-28 13:29:51 +05:30
Meatfucker 40528e9ae7 Fix typos in kandinsky5 docs (#12552)
Update kandinsky5.md

Fix typos
2025-10-28 02:54:24 -03:00
Wang, Yi dc622a95d0 fix crash if tiling mode is enabled (#12521)
* fix crash in tiling mode is enabled

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fmt

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-10-27 17:59:20 -10:00
Dhruv Nair ecfbc8f952 [Pipelines] Enable Wan VACE to run since single transformer (#12428)
* update

* update

* update

* update

* update
2025-10-28 09:21:31 +05:30
Sayak Paul df0e2a4f2c support latest few-step wan LoRA. (#12541)
* support latest few-step wan LoRA.

* up

* up
2025-10-28 08:55:24 +05:30
G.O.D 303efd2b8d Improve pos embed for Flux.1 inference on Ascend NPU (#12534)
improve pos embed for ascend npu

Co-authored-by: felix01.yu <felix01.yu@vipshop.com>
2025-10-27 16:55:36 -10:00
Lev Novitskiy 5afbcce176 Kandinsky 5 10 sec (NABLA suport) (#12520)
* add transformer pipeline first version

* updates

* fix 5sec generation

* rewrite Kandinsky5T2VPipeline to diffusers style

* add multiprompt support

* remove prints in pipeline

* add nabla attention

* Wrap Transformer in Diffusers style

* fix license

* fix prompt type

* add gradient checkpointing and peft support

* add usage example

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* remove unused imports

* add 10 second models support

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove no_grad and simplified prompt paddings

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* moved template to __init__

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* moved sdps inside processor

* remove oneline function

* remove reset_dtype methods

* Transformer: move all methods to forward

* separated prompt encoding

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* refactoring

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* refactoring acording to https://github.com/huggingface/diffusers/commit/acabbc0033d4b4933fc651766a4aa026db2e6dc1

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* fixed

* style +copies

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: Charles <charles@huggingface.co>

* more

* Apply suggestions from code review

* add lora loader doc

* add compiled Nabla Attention

* all needed changes for 10 sec models are added!

* add docs

* Apply style fixes

* update docs

* add kandinsky5 to toctree

* add tests

* fix tests

* Apply style fixes

* update tests

---------

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Charles <charles@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-10-28 07:47:18 +05:30
alirezafarashah 6d1a648602 Fix small inconsistency in output dimension of "_get_t5_prompt_embeds" function in sd3 pipeline (#12531)
* Fix small inconsistency in output dimension of t5 embeds when text_encoder_3 is None

* first commit

---------

Co-authored-by: Alireza Farashah <alireza.farashah@cn-g017.server.mila.quebec>
Co-authored-by: Alireza Farashah <alireza.farashah@login-2.server.mila.quebec>
2025-10-27 07:16:43 -10:00
Mikko Lauri 250f5cb53d Add AITER attention backend (#12549)
* add aiter attention backend

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-10-27 20:25:02 +05:30
josephrocca dc6bd1511a Fix Chroma attention padding order and update docs to use lodestones/Chroma1-HD (#12508)
* [Fix] Move attention mask padding after T5 embedding

* [Fix] Move attention mask padding after T5 embedding

* Clean up whitespace in pipeline_chroma.py

Removed unnecessary blank lines for cleaner code.

* Fix

* Fix

* Update model to final Chroma1-HD checkpoint

* Update to Chroma1-HD

* Update model to Chroma1-HD

* Update model to Chroma1-HD

* Update Chroma model links to Chroma1-HD

* Add comment about padding/masking

* Fix checkpoint/repo references

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-10-27 16:25:20 +05:30
Sayak Paul 500b9cf184 [chore] Move guiders experimental warning (#12543)
* move guiders experimental warning to init.

* up
2025-10-26 07:41:23 -10:00
Dhruv Nair d34b18c783 Deprecate Stable Cascade (#12537)
* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-10-24 22:06:31 +05:30
kaixuanliu 7536f647e4 Loose the criteria tolerance appropriately for Intel XPU devices (#12460)
* Loose the criteria tolerance appropriately for Intel XPU devices

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* change back the atol value

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* use expectations

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* Update tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
2025-10-24 12:18:15 +02:00
YiYi Xu a138d71ec1 HunyuanImage21 (#12333)
* add hunyuanimage2.1


---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-10-23 22:31:12 -10:00
Sayak Paul bc4039886d fix constants.py to user upper() (#12479) 2025-10-24 12:00:02 +05:30
Dhruv Nair 9c3b58dcf1 Handle deprecated transformer classes (#12517)
* update

* update

* update
2025-10-23 16:22:07 +05:30
Aishwarya Badlani 74b5fed434 Fix MPS compatibility in get_1d_sincos_pos_embed_from_grid #12432 (#12449)
* Fix MPS compatibility in get_1d_sincos_pos_embed_from_grid #12432

* Fix trailing whitespace in docstring

* Apply style fixes

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-10-23 16:18:07 +05:30
kaixuanliu 85eb505672 fix CI bug for kandinsky3_img2img case (#12474)
* fix CI bug for kandinsky3_img2img case

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
2025-10-23 16:17:22 +05:30
Sayak Paul ccdd96ca52 [tests] Test attention backends (#12388)
* add a lightweight test suite for attention backends.

* up

* up

* Apply suggestions from code review

* formatting
2025-10-23 15:09:41 +05:30
Sayak Paul 4c723d8ec3 [CI] xfail the test_wuerstchen_prior test (#12530)
xfail the test_wuerstchen_prior test
2025-10-22 08:45:47 -10:00
YiYi Xu bec2d8eaea Fix: Add _skip_keys for AutoencoderKLWan (#12523)
add
2025-10-22 07:53:13 -10:00
Álvaro Somoza a0a51eb098 Kandinsky5 No cfg fix (#12527)
fix
2025-10-22 22:02:47 +05:30
Sayak Paul a5a0ccf86a [core] AutoencoderMixin to abstract common methods (#12473)
* up

* correct wording.

* up

* up

* up
2025-10-22 08:52:06 +05:30
David Bertoin dd07b19e27 Prx (#12525)
* rename photon to prx

* rename photon into prx

* Revert .gitignore to state before commit b7fb0fe9d6

* rename photon to prx

* rename photon into prx

* Revert .gitignore to state before commit b7fb0fe9d6

* make fix-copies
2025-10-21 17:09:22 -07:00
vb 57636ad4f4 purge HF_HUB_ENABLE_HF_TRANSFER; promote Xet (#12497)
* purge HF_HUB_ENABLE_HF_TRANSFER; promote Xet

* purge HF_HUB_ENABLE_HF_TRANSFER; promote Xet x2

* restrict docker build test to the ones we actually use in CI.

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-10-22 00:59:20 +05:30
David Bertoin cefc2cf82d Add Photon model and pipeline support (#12456)
* Add Photon model and pipeline support

This commit adds support for the Photon image generation model:
- PhotonTransformer2DModel: Core transformer architecture
- PhotonPipeline: Text-to-image generation pipeline
- Attention processor updates for Photon-specific attention mechanism
- Conversion script for loading Photon checkpoints
- Documentation and tests

* just store the T5Gemma encoder

* enhance_vae_properties if vae is provided only

* remove autocast for text encoder forwad

* BF16 example

* conditioned CFG

* remove enhance vae and use vae.config directly when possible

* move PhotonAttnProcessor2_0 in transformer_photon

* remove einops dependency and now inherits from AttentionMixin

* unify the structure of the forward block

* update doc

* update doc

* fix T5Gemma loading from hub

* fix timestep shift

* remove lora support from doc

* Rename EmbedND for PhotoEmbedND

* remove modulation dataclass

* put _attn_forward and _ffn_forward logic in PhotonBlock's forward

* renam LastLayer for FinalLayer

* remove lora related code

* rename vae_spatial_compression_ratio for vae_scale_factor

* support prompt_embeds in call

* move xattention conditionning out computation out of the denoising loop

* add negative prompts

* Use _import_structure for lazy loading

* make quality + style

* add pipeline test + corresponding fixes

* utility function that determines the default resolution given the VAE

* Refactor PhotonAttention to match Flux pattern

* built-in RMSNorm

* Revert accidental .gitignore change

* parameter names match the standard diffusers conventions

* renaming and remove unecessary attributes setting

* Update docs/source/en/api/pipelines/photon.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* quantization example

* added doc to toctree

* Update docs/source/en/api/pipelines/photon.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/photon.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/photon.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* use dispatch_attention_fn for multiple attention backend support

* naming changes

* make fix copy

* Update docs/source/en/api/pipelines/photon.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Add PhotonTransformer2DModel to TYPE_CHECKING imports

* make fix-copies

* Use Tuple instead of tuple

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* restrict the version of transformers

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/photon/test_pipeline_photon.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/photon/test_pipeline_photon.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* change | for Optional

* fix nits.

* use typing Dict

---------

Co-authored-by: davidb <davidb@worker-10.soperator-worker-svc.soperator.svc.cluster.local>
Co-authored-by: David Briand <david@photoroom.com>
Co-authored-by: davidb <davidb@worker-8.soperator-worker-svc.soperator.svc.cluster.local>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
2025-10-21 20:55:55 +05:30
Sayak Paul b3e56e71fb styling issues. (#12522) 2025-10-21 20:04:54 +05:30
286 changed files with 28550 additions and 2573 deletions
+1 -1
View File
@@ -7,7 +7,7 @@ on:
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
+29 -8
View File
@@ -42,18 +42,39 @@ jobs:
CHANGED_FILES: ${{ steps.file_changes.outputs.all }}
run: |
echo "$CHANGED_FILES"
for FILE in $CHANGED_FILES; do
ALLOWED_IMAGES=(
diffusers-pytorch-cpu
diffusers-pytorch-cuda
diffusers-pytorch-xformers-cuda
diffusers-pytorch-minimum-cuda
diffusers-doc-builder
)
declare -A IMAGES_TO_BUILD=()
for FILE in $CHANGED_FILES; do
# skip anything that isn't still on disk
if [[ ! -f "$FILE" ]]; then
if [[ ! -e "$FILE" ]]; then
echo "Skipping removed file $FILE"
continue
fi
if [[ "$FILE" == docker/*Dockerfile ]]; then
DOCKER_PATH="${FILE%/Dockerfile}"
DOCKER_TAG=$(basename "$DOCKER_PATH")
echo "Building Docker image for $DOCKER_TAG"
docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
fi
for IMAGE in "${ALLOWED_IMAGES[@]}"; do
if [[ "$FILE" == docker/${IMAGE}/* ]]; then
IMAGES_TO_BUILD["$IMAGE"]=1
fi
done
done
if [[ ${#IMAGES_TO_BUILD[@]} -eq 0 ]]; then
echo "No relevant Docker changes detected."
exit 0
fi
for IMAGE in "${!IMAGES_TO_BUILD[@]}"; do
DOCKER_PATH="docker/${IMAGE}"
echo "Building Docker image for $IMAGE"
docker build -t "$IMAGE" "$DOCKER_PATH"
done
if: steps.file_changes.outputs.all != ''
+22 -8
View File
@@ -7,7 +7,7 @@ on:
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 600
@@ -73,6 +73,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -84,7 +86,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
--report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
tests/pipelines/${{ matrix.module }}
@@ -126,6 +128,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
@@ -138,7 +142,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
--report-log=tests_torch_${{ matrix.module }}_cuda.log \
tests/${{ matrix.module }}
@@ -151,7 +155,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v --make-reports=examples_torch_cuda \
--make-reports=examples_torch_cuda \
--report-log=examples_torch_cuda.log \
examples/
@@ -190,6 +194,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -198,7 +204,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -232,6 +238,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -281,6 +289,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -293,7 +303,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_version_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -358,6 +368,8 @@ jobs:
uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
fi
uv pip install pytest-reportlog
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -405,6 +417,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip install -U bitsandbytes optimum_quanto
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -531,7 +545,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
@@ -587,7 +601,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
+4 -3
View File
@@ -26,7 +26,7 @@ concurrency:
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
@@ -109,7 +109,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -120,7 +121,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
+9 -7
View File
@@ -22,7 +22,7 @@ concurrency:
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
@@ -115,7 +115,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -126,7 +127,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/pipelines
@@ -134,7 +135,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_models' }}
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and not Dependency" \
-k "not Flax and not Onnx and not Dependency" \
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
@@ -246,7 +247,8 @@ jobs:
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
uv pip install -U tokenizers
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -255,11 +257,11 @@ jobs:
- name: Run fast PyTorch LoRA tests with PEFT
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
\
--make-reports=tests_peft_main \
tests/lora/
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
\
--make-reports=tests_models_lora_peft_main \
tests/models/ -k "lora"
+19 -16
View File
@@ -1,4 +1,4 @@
name: Fast GPU Tests on PR
name: Fast GPU Tests on PR
on:
pull_request:
@@ -24,7 +24,7 @@ env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
@@ -71,7 +71,7 @@ jobs:
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
setup_torch_cuda_pipeline_matrix:
needs: [check_code_quality, check_repository_consistency]
name: Setup Torch Pipelines CUDA Slow Tests Matrix
@@ -131,7 +131,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -149,18 +150,18 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
else
else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
-k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
fi
fi
- name: Failure short reports
if: ${{ failure() }}
@@ -201,7 +202,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -222,11 +224,11 @@ jobs:
run: |
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
if [ -z "$pattern" ]; then
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
else
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
fi
- name: Failure short reports
@@ -262,7 +264,8 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install -e ".[quality,training]"
- name: Environment
@@ -274,7 +277,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+12 -6
View File
@@ -14,7 +14,7 @@ env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 50000
@@ -76,6 +76,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -86,7 +88,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -127,6 +129,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -139,7 +143,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_cuda_${{ matrix.module }} \
tests/${{ matrix.module }}
@@ -178,6 +182,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -186,7 +192,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -227,7 +233,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -270,7 +276,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+2 -2
View File
@@ -18,7 +18,7 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
RUN_SLOW: no
@@ -70,7 +70,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch' }}
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
+2 -2
View File
@@ -8,7 +8,7 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
RUN_SLOW: no
@@ -57,7 +57,7 @@ jobs:
HF_HOME: /System/Volumes/Data/mnt/cache
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
${CONDA_RUN} python -m pytest -n 0 --make-reports=tests_torch_mps tests/
- name: Failure short reports
if: ${{ failure() }}
+6 -6
View File
@@ -84,7 +84,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -137,7 +137,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
tests/${{ matrix.module }}
@@ -187,7 +187,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -240,7 +240,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -281,7 +281,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -326,7 +326,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+1 -1
View File
@@ -33,7 +33,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
hf_transfer \
hf_xet \
setuptools==69.5.1 \
bitsandbytes \
torchao \
+1 -1
View File
@@ -44,6 +44,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
scipy \
tensorboard \
transformers \
hf_transfer
hf_xet
CMD ["/bin/bash"]
+2 -3
View File
@@ -38,13 +38,12 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
datasets \
hf-doc-builder \
huggingface-hub \
hf_transfer \
hf_xet \
Jinja2 \
librosa \
numpy==1.26.4 \
scipy \
tensorboard \
transformers \
hf_transfer
transformers
CMD ["/bin/bash"]
+1 -1
View File
@@ -31,7 +31,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
hf_transfer
hf_xet
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
+1 -1
View File
@@ -44,6 +44,6 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_transfer
hf_xet
CMD ["/bin/bash"]
@@ -47,6 +47,6 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_transfer
hf_xet
CMD ["/bin/bash"]
@@ -44,7 +44,7 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_transfer \
hf_xet \
xformers
CMD ["/bin/bash"]
+107 -88
View File
@@ -1,5 +1,4 @@
- title: Get started
sections:
- sections:
- local: index
title: Diffusers
- local: installation
@@ -8,9 +7,8 @@
title: Quickstart
- local: stable_diffusion
title: Basic performance
- title: Pipelines
isExpanded: false
title: Get started
- isExpanded: false
sections:
- local: using-diffusers/loading
title: DiffusionPipeline
@@ -24,13 +22,14 @@
title: Reproducibility
- local: using-diffusers/schedulers
title: Schedulers
- local: using-diffusers/automodel
title: AutoModel
- local: using-diffusers/other-formats
title: Model formats
- local: using-diffusers/push_to_hub
title: Sharing pipelines and models
- title: Adapters
isExpanded: false
title: Pipelines
- isExpanded: false
sections:
- local: tutorials/using_peft_for_inference
title: LoRA
@@ -44,9 +43,8 @@
title: DreamBooth
- local: using-diffusers/textual_inversion_inference
title: Textual inversion
- title: Inference
isExpanded: false
title: Adapters
- isExpanded: false
sections:
- local: using-diffusers/weighted_prompts
title: Prompting
@@ -56,9 +54,8 @@
title: Batch inference
- local: training/distributed_inference
title: Distributed inference
- title: Inference optimization
isExpanded: false
title: Inference
- isExpanded: false
sections:
- local: optimization/fp16
title: Accelerate inference
@@ -70,8 +67,7 @@
title: Reduce memory usage
- local: optimization/speed-memory-optims
title: Compiling and offloading quantized models
- title: Community optimizations
sections:
- sections:
- local: optimization/pruna
title: Pruna
- local: optimization/xformers
@@ -90,9 +86,9 @@
title: ParaAttention
- local: using-diffusers/image_quality
title: FreeU
- title: Hybrid Inference
isExpanded: false
title: Community optimizations
title: Inference optimization
- isExpanded: false
sections:
- local: hybrid_inference/overview
title: Overview
@@ -102,9 +98,8 @@
title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
- title: Modular Diffusers
isExpanded: false
title: Hybrid Inference
- isExpanded: false
sections:
- local: modular_diffusers/overview
title: Overview
@@ -126,9 +121,10 @@
title: ComponentsManager
- local: modular_diffusers/guiders
title: Guiders
- title: Training
isExpanded: false
- local: modular_diffusers/custom_blocks
title: Building Custom Blocks
title: Modular Diffusers
- isExpanded: false
sections:
- local: training/overview
title: Overview
@@ -138,8 +134,7 @@
title: Adapt a model to a new task
- local: tutorials/basic_training
title: Train a diffusion model
- title: Models
sections:
- sections:
- local: training/unconditional_training
title: Unconditional image generation
- local: training/text2image
@@ -158,8 +153,8 @@
title: InstructPix2Pix
- local: training/cogvideox
title: CogVideoX
- title: Methods
sections:
title: Models
- sections:
- local: training/text_inversion
title: Textual Inversion
- local: training/dreambooth
@@ -172,9 +167,9 @@
title: Latent Consistency Distillation
- local: training/ddpo
title: Reinforcement learning training with DDPO
- title: Quantization
isExpanded: false
title: Methods
title: Training
- isExpanded: false
sections:
- local: quantization/overview
title: Getting started
@@ -188,9 +183,8 @@
title: quanto
- local: quantization/modelopt
title: NVIDIA ModelOpt
- title: Model accelerators and hardware
isExpanded: false
title: Quantization
- isExpanded: false
sections:
- local: optimization/onnx
title: ONNX
@@ -204,9 +198,8 @@
title: Intel Gaudi
- local: optimization/neuron
title: AWS Neuron
- title: Specific pipeline examples
isExpanded: false
title: Model accelerators and hardware
- isExpanded: false
sections:
- local: using-diffusers/consisid
title: ConsisID
@@ -232,12 +225,10 @@
title: Stable Video Diffusion
- local: using-diffusers/marigold_usage
title: Marigold Computer Vision
- title: Resources
isExpanded: false
title: Specific pipeline examples
- isExpanded: false
sections:
- title: Task recipes
sections:
- sections:
- local: using-diffusers/unconditional_image_generation
title: Unconditional image generation
- local: using-diffusers/conditional_image_generation
@@ -252,6 +243,7 @@
title: Video generation
- local: using-diffusers/depth2img
title: Depth-to-image
title: Task recipes
- local: using-diffusers/write_own_pipeline
title: Understanding pipelines, models and schedulers
- local: community_projects
@@ -266,12 +258,10 @@
title: Diffusers' Ethical Guidelines
- local: conceptual/evaluation
title: Evaluating Diffusion Models
- title: API
isExpanded: false
title: Resources
- isExpanded: false
sections:
- title: Main Classes
sections:
- sections:
- local: api/configuration
title: Configuration
- local: api/logging
@@ -282,8 +272,8 @@
title: Quantization
- local: api/parallel
title: Parallel inference
- title: Modular
sections:
title: Main Classes
- sections:
- local: api/modular_diffusers/pipeline
title: Pipeline
- local: api/modular_diffusers/pipeline_blocks
@@ -294,8 +284,8 @@
title: Components and configs
- local: api/modular_diffusers/guiders
title: Guiders
- title: Loaders
sections:
title: Modular
- sections:
- local: api/loaders/ip_adapter
title: IP-Adapter
- local: api/loaders/lora
@@ -310,14 +300,13 @@
title: SD3Transformer2D
- local: api/loaders/peft
title: PEFT
- title: Models
sections:
title: Loaders
- sections:
- local: api/models/overview
title: Overview
- local: api/models/auto_model
title: AutoModel
- title: ControlNets
sections:
- sections:
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_union
@@ -332,16 +321,20 @@
title: SD3ControlNetModel
- local: api/models/controlnet_sparsectrl
title: SparseControlNetModel
- title: Transformers
sections:
title: ControlNets
- sections:
- local: api/models/allegro_transformer3d
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/transformer_bria_fibo
title: BriaFiboTransformer2DModel
- local: api/models/bria_transformer
title: BriaTransformer2DModel
- local: api/models/chroma_transformer
title: ChromaTransformer2DModel
- local: api/models/chronoedit_transformer_3d
title: ChronoEditTransformer3DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- local: api/models/cogview3plus_transformer2d
@@ -362,6 +355,8 @@
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/hunyuanimage_transformer_2d
title: HunyuanImageTransformer2DModel
- local: api/models/hunyuan_video_transformer_3d
title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
@@ -384,6 +379,8 @@
title: QwenImageTransformer2DModel
- local: api/models/sana_transformer2d
title: SanaTransformer2DModel
- local: api/models/sana_video_transformer3d
title: SanaVideoTransformer3DModel
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/skyreels_v2_transformer_3d
@@ -394,10 +391,12 @@
title: Transformer2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
- local: api/models/wan_animate_transformer_3d
title: WanAnimateTransformer3DModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
- title: UNets
sections:
title: Transformers
- sections:
- local: api/models/stable_cascade_unet
title: StableCascadeUNet
- local: api/models/unet
@@ -412,8 +411,8 @@
title: UNetMotionModel
- local: api/models/uvit2d
title: UViT2DModel
- title: VAEs
sections:
title: UNets
- sections:
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
@@ -426,6 +425,10 @@
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_cosmos
title: AutoencoderKLCosmos
- local: api/models/autoencoder_kl_hunyuanimage
title: AutoencoderKLHunyuanImage
- local: api/models/autoencoder_kl_hunyuanimage_refiner
title: AutoencoderKLHunyuanImageRefiner
- local: api/models/autoencoder_kl_hunyuan_video
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
@@ -446,14 +449,26 @@
title: Tiny AutoEncoder
- local: api/models/vq
title: VQModel
- title: Pipelines
sections:
title: VAEs
title: Models
- sections:
- local: api/pipelines/overview
title: Overview
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- title: Image
sections:
- sections:
- local: api/pipelines/audioldm
title: AudioLDM
- local: api/pipelines/audioldm2
title: AudioLDM 2
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/musicldm
title: MusicLDM
- local: api/pipelines/stable_audio
title: Stable Audio
title: Audio
- sections:
- local: api/pipelines/amused
title: aMUSEd
- local: api/pipelines/animatediff
@@ -466,6 +481,8 @@
title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3
@@ -514,6 +531,8 @@
title: HiDream-I1
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuanimage21
title: HunyuanImage2.1
- local: api/pipelines/pix2pix
title: InstructPix2Pix
- local: api/pipelines/kandinsky
@@ -548,12 +567,16 @@
title: PixArt-α
- local: api/pipelines/pixart_sigma
title: PixArt-Σ
- local: api/pipelines/prx
title: PRX
- local: api/pipelines/qwenimage
title: QwenImage
- local: api/pipelines/sana
title: Sana
- local: api/pipelines/sana_sprint
title: Sana Sprint
- local: api/pipelines/sana_video
title: Sana Video
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
@@ -562,8 +585,7 @@
title: Shap-E
- local: api/pipelines/stable_cascade
title: Stable Cascade
- title: Stable Diffusion
sections:
- sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview
- local: api/pipelines/stable_diffusion/depth2img
@@ -581,7 +603,8 @@
- local: api/pipelines/stable_diffusion/latent_upscale
title: Latent upscaler
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D
Upscaler
- local: api/pipelines/stable_diffusion/stable_diffusion_safe
title: Safe Stable Diffusion
- local: api/pipelines/stable_diffusion/sdxl_turbo
@@ -598,6 +621,7 @@
title: T2I-Adapter
- local: api/pipelines/stable_diffusion/text2img
title: Text-to-image
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
- local: api/pipelines/unclip
@@ -610,10 +634,12 @@
title: VisualCloze
- local: api/pipelines/wuerstchen
title: Wuerstchen
- title: Video
sections:
title: Image
- sections:
- local: api/pipelines/allegro
title: Allegro
- local: api/pipelines/chronoedit
title: ChronoEdit
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/consisid
@@ -624,6 +650,8 @@
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
title: I2VGen-XL
- local: api/pipelines/kandinsky5_video
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ltx_video
@@ -642,20 +670,9 @@
title: Text2Video-Zero
- local: api/pipelines/wan
title: Wan
- title: Audio
sections:
- local: api/pipelines/audioldm
title: AudioLDM
- local: api/pipelines/audioldm2
title: AudioLDM 2
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/musicldm
title: MusicLDM
- local: api/pipelines/stable_audio
title: Stable Audio
- title: Schedulers
sections:
title: Video
title: Pipelines
- sections:
- local: api/schedulers/overview
title: Overview
- local: api/schedulers/cm_stochastic_iterative
@@ -724,8 +741,8 @@
title: UniPCMultistepScheduler
- local: api/schedulers/vq_diffusion
title: VQDiffusionScheduler
- title: Internal classes
sections:
title: Schedulers
- sections:
- local: api/internal_classes_overview
title: Overview
- local: api/attnprocessor
@@ -741,4 +758,6 @@
- local: api/image_processor
title: VAE Image Processor
- local: api/video_processor
title: Video Processor
title: Video Processor
title: Internal classes
title: API
+1 -9
View File
@@ -12,15 +12,7 @@ specific language governing permissions and limitations under the License.
# AutoModel
The `AutoModel` is designed to make it easy to load a checkpoint without needing to know the specific model class. `AutoModel` automatically retrieves the correct model class from the checkpoint `config.json` file.
```python
from diffusers import AutoModel, AutoPipelineForText2Image
unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
pipe = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet)
```
[`AutoModel`] automatically retrieves the correct model class from the checkpoint `config.json` file.
## AutoModel
@@ -0,0 +1,32 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanImage
The 2D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1].
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanImage
vae = AutoencoderKLHunyuanImage.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
## AutoencoderKLHunyuanImage
[[autodoc]] AutoencoderKLHunyuanImage
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -0,0 +1,32 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanImageRefiner
The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) for its refiner pipeline.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanImageRefiner
vae = AutoencoderKLHunyuanImageRefiner.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
## AutoencoderKLHunyuanImageRefiner
[[autodoc]] AutoencoderKLHunyuanImageRefiner
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# ChromaTransformer2DModel
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
## ChromaTransformer2DModel
@@ -0,0 +1,32 @@
<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# ChronoEditTransformer3DModel
A Diffusion Transformer model for 3D video-like data from [ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
The model can be loaded with the following code snippet.
```python
from diffusers import ChronoEditTransformer3DModel
transformer = ChronoEditTransformer3DModel.from_pretrained("nvidia/ChronoEdit-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## ChronoEditTransformer3DModel
[[autodoc]] ChronoEditTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,30 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# HunyuanImageTransformer2DModel
A Diffusion Transformer model for [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
The model can be loaded with the following code snippet.
```python
from diffusers import HunyuanImageTransformer2DModel
transformer = HunyuanImageTransformer2DModel.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HunyuanImageTransformer2DModel
[[autodoc]] HunyuanImageTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,36 @@
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# SanaVideoTransformer3DModel
A Diffusion Transformer model for 3D data (video) from [SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
The abstract from the paper is:
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.*
The model can be loaded with the following code snippet.
```python
from diffusers import SanaVideoTransformer3DModel
import torch
transformer = SanaVideoTransformer3DModel.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## SanaVideoTransformer3DModel
[[autodoc]] SanaVideoTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BriaFiboTransformer2DModel
A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
## BriaFiboTransformer2DModel
[[autodoc]] BriaFiboTransformer2DModel
@@ -0,0 +1,30 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# WanAnimateTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team.
The model can be loaded with the following code snippet.
```python
from diffusers import WanAnimateTransformer3DModel
transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## WanAnimateTransformer3DModel
[[autodoc]] WanAnimateTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
+45
View File
@@ -0,0 +1,45 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria Fibo
Text-to-image models have mastered imagination - but not control. FIBO changes that.
FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control.
FIBO is trained exclusively on a structured prompt and will not work with freeform text prompts.
you can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt.
its not recommended to use freeform text prompts directly with FIBO, as it will not produce the best results.
you can learn more about FIBO in [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO).
## Usage
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. Once you are in, you need to login so that your system knows youve accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
## BriaPipeline
[[autodoc]] BriaPipeline
- all
- __call__
+7 -6
View File
@@ -19,20 +19,21 @@ specific language governing permissions and limitations under the License.
Chroma is a text to image generation model based on Flux.
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
Original model checkpoints for Chroma can be found here:
* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
> [!TIP]
> Chroma can use all the same optimizations as Flux.
## Inference
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
```python
import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = [
@@ -63,10 +64,10 @@ Then run the following example
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
model_id = "lodestones/Chroma"
model_id = "lodestones/Chroma1-HD"
dtype = torch.bfloat16
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
+156
View File
@@ -0,0 +1,156 @@
<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</a>
</div>
</div>
# ChronoEdit
[ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
*Recent advances in large generative models have greatly enhanced both image editing and in-context image generation, yet a critical gap remains in ensuring physical consistency, where edited objects must remain coherent. This capability is especially vital for world simulation related tasks. In this paper, we present ChronoEdit, a framework that reframes image editing as a video generation problem. First, ChronoEdit treats the input and edited images as the first and last frames of a video, allowing it to leverage large pretrained video generative models that capture not only object appearance but also the implicit physics of motion and interaction through learned temporal consistency. Second, ChronoEdit introduces a temporal reasoning stage that explicitly performs editing at inference time. Under this setting, target frame is jointly denoised with reasoning tokens to imagine a plausible editing trajectory that constrains the solution space to physically viable transformations. The reasoning tokens are then dropped after a few steps to avoid the high computational cost of rendering a full video. To validate ChronoEdit, we introduce PBench-Edit, a new benchmark of image-prompt pairs for contexts that require physical consistency, and demonstrate that ChronoEdit surpasses state-of-the-art baselines in both visual fidelity and physical plausibility. Project page for code and models: [this https URL](https://research.nvidia.com/labs/toronto-ai/chronoedit).*
The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
### Image Editing
```py
import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
model_id = "nvidia/ChronoEdit-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = load_image(
"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
)
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print("width", width, "height", height)
image = image.resize((width, height))
prompt = (
"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cups liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
"The mouses pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacups floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
)
output = pipe(
image=image,
prompt=prompt,
height=height,
width=width,
num_frames=5,
num_inference_steps=50,
guidance_scale=5.0,
enable_temporal_reasoning=False,
num_temporal_reasoning_steps=0,
).frames[0]
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
```
Optionally, enable **temporal reasoning** for improved physical consistency:
```py
output = pipe(
image=image,
prompt=prompt,
height=height,
width=width,
num_frames=29,
num_inference_steps=50,
guidance_scale=5.0,
enable_temporal_reasoning=True,
num_temporal_reasoning_steps=50,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
```
### Inference with 8-Step Distillation Lora
```py
import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
model_id = "nvidia/ChronoEdit-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=1.0)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
pipe.to("cuda")
image = load_image(
"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
)
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print("width", width, "height", height)
image = image.resize((width, height))
prompt = (
"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cups liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
"The mouses pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacups floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
)
output = pipe(
image=image,
prompt=prompt,
height=height,
width=width,
num_frames=5,
num_inference_steps=8,
guidance_scale=1.0,
enable_temporal_reasoning=False,
num_temporal_reasoning_steps=0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
```
## ChronoEditPipeline
[[autodoc]] ChronoEditPipeline
- all
- __call__
## ChronoEditPipelineOutput
[[autodoc]] pipelines.chronoedit.pipeline_output.ChronoEditPipelineOutput
@@ -0,0 +1,152 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# HunyuanImage2.1
HunyuanImage-2.1 is a 17B text-to-image model that is capable of generating 2K (2048 x 2048) resolution images
HunyuanImage-2.1 comes in the following variants:
| model type | model id |
|:----------:|:--------:|
| HunyuanImage-2.1 | [hunyuanvideo-community/HunyuanImage-2.1-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Diffusers) |
| HunyuanImage-2.1-Distilled | [hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers) |
| HunyuanImage-2.1-Refiner | [hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers) |
> [!TIP]
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## HunyuanImage-2.1
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
```python
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained(
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
```
You can inspect the `guider` object:
```py
>>> pipe.guider
AdaptiveProjectedMixGuidance {
"_class_name": "AdaptiveProjectedMixGuidance",
"_diffusers_version": "0.36.0.dev0",
"adaptive_projected_guidance_momentum": -0.5,
"adaptive_projected_guidance_rescale": 10.0,
"adaptive_projected_guidance_scale": 10.0,
"adaptive_projected_guidance_start_step": 5,
"enabled": true,
"eta": 0.0,
"guidance_rescale": 0.0,
"guidance_scale": 3.5,
"start": 0.0,
"stop": 1.0,
"use_original_formulation": false
}
State:
step: None
num_inference_steps: None
timestep: None
count_prepared: 0
enabled: True
num_conditions: 2
momentum_buffer: None
is_apg_enabled: False
is_cfg_enabled: True
```
To update the guider with a different configuration, use the `new()` method. For example, to generate an image with `guidance_scale=5.0` while keeping all other default guidance parameters:
```py
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained(
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
# Update the guider configuration
pipe.guider = pipe.guider.new(guidance_scale=5.0)
prompt = (
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
image = pipe(
prompt=prompt,
num_inference_steps=50,
height=2048,
width=2048,
).images[0]
image.save("image.png")
```
## HunyuanImage-2.1-Distilled
use `distilled_guidance_scale` with the guidance-distilled checkpoint,
```py
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
prompt = (
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
out = pipe(
prompt,
num_inference_steps=8,
distilled_guidance_scale=3.25,
height=2048,
width=2048,
generator=generator,
).images[0]
```
## HunyuanImagePipeline
[[autodoc]] HunyuanImagePipeline
- all
- __call__
## HunyuanImageRefinerPipeline
[[autodoc]] HunyuanImageRefinerPipeline
- all
- __call__
## HunyuanImagePipelineOutput
[[autodoc]] pipelines.hunyuan_image.pipeline_output.HunyuanImagePipelineOutput
@@ -0,0 +1,149 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Kandinsky 5.0 Video
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
The model introduces several key innovations:
- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
- **HunyuanVideo 3D VAE** for efficient video encoding and decoding
- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing
The original codebase can be found at [ai-forever/Kandinsky-5](https://github.com/ai-forever/Kandinsky-5).
> [!TIP]
> Check out the [AI Forever](https://huggingface.co/ai-forever) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.
## Available Models
Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases:
| model_id | Description | Use Cases |
|------------|-------------|-----------|
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |
All models are available in 5-second and 10-second video generation versions.
## Kandinsky5T2VPipeline
[[autodoc]] Kandinsky5T2VPipeline
- all
- __call__
## Usage Examples
### Basic Text-to-Video Generation
```python
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video
# Load the pipeline
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
# Generate video
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=121, # ~5 seconds at 24fps
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
### 10 second Models
**⚠️ Warning!** all 10 second models should be used with Flex attention and max-autotune-no-cudagraphs compilation:
```python
pipe = Kandinsky5T2VPipeline.from_pretrained(
"ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
pipe.transformer.set_attention_backend(
"flex"
) # <--- Sett attention bakend to Flex
pipe.transformer.compile(
mode="max-autotune-no-cudagraphs",
dynamic=True
) # <--- Compile with max-autotune-no-cudagraphs
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=241,
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
### Diffusion Distilled model
**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
```python
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
output = pipe(
prompt="A beautiful sunset over mountains",
num_inference_steps=16, # <--- Model is distilled in 16 steps
guidance_scale=1.0, # <--- no CFG
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
## Citation
```bibtex
@misc{kandinsky2025,
author = {Alexey Letunovskiy and Maria Kovaleva and Ivan Kirillov and Lev Novitskiy and Denis Koposov and
Dmitrii Mikhailov and Anna Averchenkova and Andrey Shutkin and Julia Agafonova and Olga Kim and
Anastasiia Kargapoltseva and Nikita Kiselev and Vladimir Arkhipkin and Vladimir Korviakov and
Nikolai Gerasimenko and Denis Parkhomenko and Anna Dmitrienko and Anastasia Maltseva and
Kirill Chernyshev and Ilia Vasiliev and Viacheslav Vasilev and Vladimir Polovnikov and
Yury Kolabushin and Alexander Belykh and Mikhail Mamaev and Anastasia Aliaskina and
Tatiana Nikulina and Polina Gavrilova and Denis Dimitrov},
title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}},
year = 2025
}
```
+131
View File
@@ -0,0 +1,131 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# PRX
PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.
## Available models
PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s
Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.
## Loading the pipeline
Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
```py
from diffusers.pipelines.prx import PRXPipeline
# Load pipeline - VAE and text encoder will be loaded from HuggingFace
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A front-facing portrait of a lion the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("prx_output.png")
```
### Manual Component Loading
Load components individually to customize the pipeline for instance to use quantized models.
```py
import torch
from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as BitsAndBytesConfig
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
# Load transformer
transformer = PRXTransformer2DModel.from_pretrained(
"checkpoints/prx-512-t2i-sft",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
"checkpoints/prx-512-t2i-sft", subfolder="scheduler"
)
# Load T5Gemma text encoder
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
text_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16)
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
# Load VAE - choose either Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
subfolder="vae",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
pipe = PRXPipeline(
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae
)
pipe.to("cuda")
```
## Memory Optimization
For memory-constrained environments:
```py
import torch
from diffusers.pipelines.prx import PRXPipeline
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload() # Offload components to CPU when not in use
# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```
## PRXPipeline
[[autodoc]] PRXPipeline
- all
- __call__
## PRXPipelineOutput
[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
@@ -24,9 +24,6 @@ The abstract from the paper is:
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
> [!TIP]
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
Available models:
+188
View File
@@ -0,0 +1,188 @@
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Sana-Video
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
[SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
The abstract from the paper is:
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation. [this https URL](https://github.com/NVlabs/SANA).*
This pipeline was contributed by SANA Team. The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://hf.co/collections/Efficient-Large-Model/sana-video).
Available models:
| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`Efficient-Large-Model/SANA-Video_2B_480p_diffusers`](https://huggingface.co/Efficient-Large-Model/ANA-Video_2B_480p_diffusers) | `torch.bfloat16` |
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-video) collection for more information.
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
## Generation Pipelines
<hfoptions id="generation pipelines">`
<hfoption id="Text-to-Video">
The example below demonstrates how to use the text-to-video pipeline to generate a video using a text descriptio and a starting frame.
```python
model_id =
pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", torch_dtype=torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.float32)
pipe.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
frames=81,
guidance_scale=6,
num_inference_steps=50,
generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]
export_to_video(video, "sana_video.mp4", fps=16)
```
</hfoption>
<hfoption id="Image-to-Video">
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text descriptio and a starting frame.
```python
model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
pipe = SanaImageToVideoPipeline.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
)
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.vae.to(torch.float32)
pipe.text_encoder.to(torch.bfloat16)
pipe.to("cuda")
image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt
motion_scale = 30.0
video = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
frames=81,
guidance_scale=6,
num_inference_steps=50,
generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]
export_to_video(video, "sana-i2v.mp4", fps=16)
```
</hfoption>
</hfoptions>
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaVideoPipeline`] for inference with bitsandbytes.
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModel.from_pretrained(
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
subfolder="text_encoder",
quantization_config=quant_config,
torch_dtype=torch.float16,
)
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaVideoTransformer3DModel.from_pretrained(
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.float16,
)
pipeline = SanaVideoPipeline.from_pretrained(
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
text_encoder=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
)
model_score = 30
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_prompt = f" motion score: {model_score}."
prompt = prompt + motion_prompt
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
num_frames=81,
guidance_scale=6.0,
num_inference_steps=50
).frames[0]
export_to_video(output, "sana-video-output.mp4", fps=16)
```
## SanaVideoPipeline
[[autodoc]] SanaVideoPipeline
- all
- __call__
## SanaImageToVideoPipeline
[[autodoc]] SanaImageToVideoPipeline
- all
- __call__
## SanaVideoPipelineOutput
[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput
+226 -17
View File
@@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers:
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)
> [!TIP]
> Click on the Wan models in the right sidebar for more examples of video generation.
@@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained(
pipeline.to("cuda")
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
@@ -150,15 +151,15 @@ pipeline.transformer = torch.compile(
)
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
@@ -249,6 +250,208 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
</hfoption>
</hfoptions>
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*
The project page: https://humanaigc.github.io/wan-animate
This model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
#### Usage
The Wan-Animate pipeline supports two modes of operation:
1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos
2. **Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene
##### Prerequisites
Before using the pipeline, you need to preprocess your reference video to extract:
- **Pose video**: Contains skeletal keypoints representing body motion
- **Face video**: Contains facial feature representations for expression control
For replacement mode, you additionally need:
- **Background video**: The original video containing the scene
- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black)
> [!NOTE]
> Raw videos should not be used for inputs such as `pose_video`, which the pipeline expects to be preprocessed to extract the proper information. Preprocessing scripts to prepare these inputs are available in the [original Wan-Animate repository](https://github.com/Wan-Video/Wan2.2?tab=readme-ov-file#1-preprocessing). Integration of these preprocessing steps into Diffusers is planned for a future release.
The example below demonstrates how to use the Wan-Animate pipeline:
<hfoptions id="Animate usage">
<hfoption id="Animation mode">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load character image and preprocessed videos
image = load_image("path/to/character.jpg")
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
# Resize image to match VAE constraints
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person dancing energetically in a studio with dynamic lighting and professional camera work"
negative_prompt = "blurry, low quality, distorted, deformed, static, poorly drawn"
# Generate animated video
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
segment_frame_length=77,
guidance_scale=1.0,
mode="animate", # Animation mode (default)
).frames[0]
export_to_video(output, "animated_character.mp4", fps=30)
```
</hfoption>
<hfoption id="Replacement mode">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load all required inputs for replacement mode
image = load_image("path/to/new_character.jpg")
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
background_video = load_video("path/to/background_video.mp4") # Original scene
mask_video = load_video("path/to/mask_video.mp4") # Black: preserve, White: generate
# Resize image to match video dimensions
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person seamlessly integrated into the scene with consistent lighting and environment"
negative_prompt = "blurry, low quality, inconsistent lighting, floating, disconnected from scene"
# Replace character in background video
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
background_video=background_video,
mask_video=mask_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
segment_frame_lengths=77,
guidance_scale=1.0,
mode="replace", # Replacement mode
).frames[0]
export_to_video(output, "character_replaced.mp4", fps=30)
```
</hfoption>
<hfoption id="Advanced options">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = load_image("path/to/character.jpg")
pose_video = load_video("path/to/pose_video.mp4")
face_video = load_video("path/to/face_video.mp4")
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person dancing energetically in a studio"
negative_prompt = "blurry, low quality"
# Advanced: Use temporal guidance and custom callback
def callback_fn(pipe, step_index, timestep, callback_kwargs):
# You can modify latents or other tensors here
print(f"Step {step_index}, Timestep {timestep}")
return callback_kwargs
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
segment_frame_length=77,
num_inference_steps=50,
guidance_scale=5.0,
prev_segment_conditioning_frames=5, # Use 5 frames for temporal guidance (1 or 5 recommended)
callback_on_step_end=callback_fn,
callback_on_step_end_tensor_inputs=["latents"],
).frames[0]
export_to_video(output, "animated_advanced.mp4", fps=30)
```
</hfoption>
</hfoptions>
#### Key Parameters
- **mode**: Choose between `"animate"` (default) or `"replace"`
- **prev_segment_conditioning_frames**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory
- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt. For Wan-Animate, CFG is disabled by default (`guidance_scale=1.0`) but can be enabled to support negative prompts and finer control over facial expressions. (Note that CFG will only target the text prompt and face conditioning.)
## Notes
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
@@ -281,10 +484,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
# use "steamboat willie style" to trigger the LoRA
prompt = """
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
@@ -359,6 +562,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
- all
- __call__
## WanAnimatePipeline
[[autodoc]] WanAnimatePipeline
- all
- __call__
## WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
@@ -0,0 +1,492 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Building Custom Blocks
[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.
> [!TIP]
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana.
## Project Structure
Your custom block project should use the following structure:
```shell
.
├── block.py
└── modular_config.json
```
- `block.py` contains the custom block implementation
- `modular_config.json` contains the metadata needed to load the block
## Example: Florence 2 Inpainting Block
In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.
The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub.
```py
# Inside block.py
from diffusers.modular_pipelines import (
ModularPipelineBlocks,
ComponentSpec,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
```
Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
```py
from typing import List, Union
from PIL import Image, ImageDraw
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
type_hint=Union[Image.Image, List[Image.Image]],
required=True,
description="Image(s) to annotate",
),
InputParam(
"annotation_task",
type_hint=Union[str, List[str]],
required=True,
default="<REFERRING_EXPRESSION_SEGMENTATION>",
description="""Annotation Task to perform on the image.
Supported Tasks:
<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
""",
),
InputParam(
"annotation_prompt",
type_hint=Union[str, List[str]],
required=True,
description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
),
InputParam(
"annotation_output_type",
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Availabe options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
- mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
),
InputParam(
"annotation_overlay",
type_hint=bool,
required=True,
default=False,
description="",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"mask_image",
type_hint=Image,
description="Inpainting Mask for input Image(s)",
),
OutputParam(
"annotations",
type_hint=dict,
description="Annotations Predictions for input Image(s)",
),
OutputParam(
"image",
type_hint=Image,
description="Annotated input Image(s)",
),
]
```
Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
```py
from typing import List, Union
from PIL import Image, ImageDraw
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
type_hint=Union[Image.Image, List[Image.Image]],
required=True,
description="Image(s) to annotate",
),
InputParam(
"annotation_task",
type_hint=Union[str, List[str]],
required=True,
default="<REFERRING_EXPRESSION_SEGMENTATION>",
description="""Annotation Task to perform on the image.
Supported Tasks:
<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
""",
),
InputParam(
"annotation_prompt",
type_hint=Union[str, List[str]],
required=True,
description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
),
InputParam(
"annotation_output_type",
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Availabe options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
- mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
),
InputParam(
"annotation_overlay",
type_hint=bool,
required=True,
default=False,
description="",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"mask_image",
type_hint=Image,
description="Inpainting Mask for input Image(s)",
),
OutputParam(
"annotations",
type_hint=dict,
description="Annotations Predictions for input Image(s)",
),
OutputParam(
"image",
type_hint=Image,
description="Annotated input Image(s)",
),
]
def get_annotations(self, components, images, prompts, task):
task_prompts = [task + prompt for prompt in prompts]
inputs = components.image_annotator_processor(
text=task_prompts, images=images, return_tensors="pt"
).to(components.image_annotator.device, components.image_annotator.dtype)
generated_ids = components.image_annotator.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
annotations = components.image_annotator_processor.batch_decode(
generated_ids, skip_special_tokens=False
)
outputs = []
for image, annotation in zip(images, annotations):
outputs.append(
components.image_annotator_processor.post_process_generation(
annotation, task=task, image_size=(image.width, image.height)
)
)
return outputs
def prepare_mask(self, images, annotations, overlay=False, fill="white"):
masks = []
for image, annotation in zip(images, annotations):
mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
draw = ImageDraw.Draw(mask_image)
for _, _annotation in annotation.items():
if "polygons" in _annotation:
for polygon in _annotation["polygons"]:
polygon = np.array(polygon).reshape(-1, 2)
if len(polygon) < 3:
continue
polygon = polygon.reshape(-1).tolist()
draw.polygon(polygon, fill=fill)
elif "bbox" in _annotation:
bbox = _annotation["bbox"]
draw.rectangle(bbox, fill="white")
masks.append(mask_image)
return masks
def prepare_bounding_boxes(self, images, annotations):
outputs = []
for image, annotation in zip(images, annotations):
image_copy = image.copy()
draw = ImageDraw.Draw(image_copy)
for _, _annotation in annotation.items():
bbox = _annotation["bbox"]
label = _annotation["label"]
draw.rectangle(bbox, outline="red", width=3)
draw.text((bbox[0], bbox[1] - 20), label, fill="red")
outputs.append(image_copy)
return outputs
def prepare_inputs(self, images, prompts):
prompts = prompts or ""
if isinstance(images, Image.Image):
images = [images]
if isinstance(prompts, str):
prompts = [prompts]
if len(images) != len(prompts):
raise ValueError("Number of images and annotation prompts must match.")
return images, prompts
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
images, annotation_task_prompt = self.prepare_inputs(
block_state.image, block_state.annotation_prompt
)
task = block_state.annotation_task
fill = block_state.fill
annotations = self.get_annotations(
components, images, annotation_task_prompt, task
)
block_state.annotations = annotations
if block_state.annotation_output_type == "mask_image":
block_state.mask_image = self.prepare_mask(images, annotations)
else:
block_state.mask_image = None
if block_state.annotation_output_type == "mask_overlay":
block_state.image = self.prepare_mask(images, annotations, overlay=True, fill=fill)
elif block_state.annotation_output_type == "bounding_box":
block_state.image = self.prepare_bounding_boxes(images, annotations)
self.set_block_state(state, block_state)
return components, state
```
Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines.
<hfoptions id="share">
<hfoption id="hf CLI">
```shell
# In the folder with the `block.py` file, run:
diffusers-cli custom_block
```
Then upload the block to the Hub:
```shell
hf upload <your repo id> . .
```
</hfoption>
<hfoption id="push_to_hub">
```py
from block import Florence2ImageAnnotatorBlock
block = Florence2ImageAnnotatorBlock()
block.push_to_hub("<your repo id>")
```
</hfoption>
</hfoptions>
## Using Custom Blocks
Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
# Fetch the Florence2 image annotator block that will create our mask
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
my_blocks = INPAINT_BLOCKS.copy()
# insert the annotation block before the image encoding step
my_blocks.insert("image_annotator", image_annotator_block, 1)
# Create our initial set of inpainting blocks
blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
pipe = blocks.init_pipeline(repo_id)
pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
image = image.resize((1024, 1024))
prompt = ["A red car"]
annotation_task = "<REFERRING_EXPRESSION_SEGMENTATION>"
annotation_prompt = ["the car"]
output = pipe(
prompt=prompt,
image=image,
annotation_task=annotation_task,
annotation_prompt=annotation_prompt,
annotation_output_type="mask_image",
num_inference_steps=35,
guidance_scale=7.5,
strength=0.95,
output="images"
)
output[0].save("florence-inpainting.png")
```
## Editing Custom Blocks
By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
# Fetch the Florence2 image annotator block that will create our mask
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
```
Any changes made to the block files in this folder will be reflected when you load the block again.
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# LoopSequentialPipelineBlocks
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `intermediate_inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
@@ -21,7 +21,6 @@ This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBl
[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
- `loop_intermediate_inputs` are intermediate variables from the [`~modular_pipelines.PipelineState`] and equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`].
- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
- `__call__` method defines the loop structure and iteration logic.
@@ -90,4 +89,4 @@ Add more loop blocks to run within each iteration with [`~modular_pipelines.Loop
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
```
```
@@ -37,17 +37,7 @@ A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermedi
]
```
- `intermediate_inputs` are values typically created from a previous block but it can also be directly provided if no preceding block generates them. Unlike `inputs`, `intermediate_inputs` can be modified.
Use `InputParam` to define `intermediate_inputs`.
```py
user_intermediate_inputs = [
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
]
```
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `intermediate_inputs` for subsequent blocks or available as the final output from running the pipeline.
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
Use `OutputParam` to define `intermediate_outputs`.
@@ -65,8 +55,8 @@ The intermediate inputs and outputs share data to connect blocks. They are acces
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs` and `intermediate_inputs`.
2. Implement the computation logic on the `inputs` and `intermediate_inputs`.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
2. Implement the computation logic on the `inputs`.
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
4. Return the components and state which becomes available to the next block.
@@ -76,7 +66,7 @@ def __call__(self, components, state):
block_state = self.get_block_state(state)
# Your computation logic here
# block_state contains all your inputs and intermediate_inputs
# block_state contains all your inputs
# Access them like: block_state.image, block_state.processed_image
# Update the pipeline state with your updated block_states
@@ -112,4 +102,4 @@ def __call__(self, components, state):
unet = components.unet
vae = components.vae
scheduler = components.scheduler
```
```
@@ -183,7 +183,7 @@ from diffusers.modular_pipelines import ComponentsManager
components = ComponentManager()
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
dd_pipeline.load_componenets(torch_dtype=torch.float16)
dd_pipeline.to("cuda")
```
@@ -12,11 +12,11 @@ specific language governing permissions and limitations under the License.
# SequentialPipelineBlocks
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `intermediate_inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `intermediate_inputs`.
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.
<hfoptions id="sequential">
<hfoption id="InputBlock">
@@ -110,4 +110,4 @@ Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by cal
```py
print(blocks)
print(blocks.doc)
```
```
@@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and
| attention family | main feature |
|---|---|
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
| SageAttention | quantizes attention to int8 |
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
| xFormers | memory-efficient attention with support for various attention kernels |
@@ -139,6 +140,7 @@ Refer to the table below for a complete list of available attention backends and
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
@@ -0,0 +1,46 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoModel
The [`AutoModel`] class automatically detects and loads the correct model class (UNet, transformer, VAE) from a `config.json` file. You don't need to know the specific model class name ahead of time. It supports data types and device placement, and works across model types and libraries.
The example below loads a transformer from Diffusers and a text encoder from Transformers. Use the `subfolder` parameter to specify where to load the `config.json` file from.
```py
import torch
from diffusers import AutoModel, DiffusionPipeline
transformer = AutoModel.from_pretrained(
"Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda"
)
text_encoder = AutoModel.from_pretrained(
"Qwen/Qwen-Image", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda"
)
```
[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
```py
import torch
from diffusers import AutoModel
transformer = AutoModel.from_pretrained(
"custom/custom-transformer-model", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
)
```
If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
+1 -1
View File
@@ -5488,7 +5488,7 @@ Editing at Scale", many thanks to their contribution!
This implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated.
As explained in Section 3 of [the paper](https://arxiv.org/pdf/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
As explained in Section 3 of [the paper](https://huggingface.co/papers/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
## Example Usage
+1 -1
View File
@@ -45,7 +45,7 @@ def check_size(image, height, width):
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
inner_image = inner_image.convert("RGBA")
image = image.convert("RGB")
+11 -6
View File
@@ -1966,16 +1966,21 @@ class MatryoshkaUNet2DConditionModel(
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -2294,10 +2299,10 @@ class MatryoshkaUNet2DConditionModel(
def _check_config(
self,
down_block_types: Tuple[str],
up_block_types: Tuple[str],
down_block_types: Tuple[str, ...],
up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
@@ -438,16 +438,21 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -490,7 +490,7 @@ class RegionalPromptingStableDiffusionPipeline(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -841,7 +841,7 @@ class RegionalPromptingStableDiffusionPipeline(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
@@ -872,7 +872,7 @@ class RegionalPromptingStableDiffusionPipeline(
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
@@ -1062,7 +1062,7 @@ class RegionalPromptingStableDiffusionPipeline(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
@@ -1668,7 +1668,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
+1 -2
View File
@@ -268,12 +268,11 @@ provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_f
**important**
> [!NOTE]
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source.
> To do this, execute the following steps in a new virtual environment:
> ```
> git clone https://github.com/huggingface/diffusers
> cd diffusers
> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
> pip install -e .
> ```
@@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways:
- you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
Below, we explain both in more detail.
#### Provide the dataset as a folder
@@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
return res.expand(broadcast_shape)
def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
"""
Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
"""
if tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
channels = tensor.shape[0]
if channels == 3:
return tensor
if channels == 1:
return tensor.repeat(3, 1, 1)
if channels == 2:
return torch.cat([tensor, tensor[:1]], dim=0)
if channels > 3:
return tensor[:3]
raise ValueError(f"Unsupported number of channels: {channels}")
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -260,6 +278,11 @@ def parse_args():
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--preserve_input_precision",
action="store_true",
help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -453,19 +476,41 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation.
spatial_augmentations = [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
]
augmentations = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
spatial_augmentations
+ [
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
precision_augmentations = transforms.Compose(
[
transforms.PILToTensor(),
transforms.Lambda(_ensure_three_channels),
transforms.ConvertImageDtype(torch.float32),
]
+ spatial_augmentations
+ [transforms.Normalize([0.5], [0.5])]
)
def transform_images(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
processed = []
for image in examples["image"]:
if not args.preserve_input_precision:
processed.append(augmentations(image.convert("RGB")))
else:
precise_image = image
if precise_image.mode == "P":
precise_image = precise_image.convert("RGB")
processed.append(precision_augmentations(precise_image))
return {"input": processed}
logger.info(f"Dataset size: {len(dataset)}")
File diff suppressed because it is too large Load Diff
+345
View File
@@ -0,0 +1,345 @@
#!/usr/bin/env python3
"""
Script to convert PRX checkpoint from original codebase to diffusers format.
"""
import argparse
import json
import os
import sys
from dataclasses import asdict, dataclass
from typing import Dict, Tuple
import torch
from safetensors.torch import save_file
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx import PRXPipeline
DEFAULT_RESOLUTION = 512
@dataclass(frozen=True)
class PRXBase:
context_in_dim: int = 2304
hidden_size: int = 1792
mlp_ratio: float = 3.5
num_heads: int = 28
depth: int = 16
axes_dim: Tuple[int, int] = (32, 32)
theta: int = 10_000
time_factor: float = 1000.0
time_max_period: int = 10_000
@dataclass(frozen=True)
class PRXFlux(PRXBase):
in_channels: int = 16
patch_size: int = 2
@dataclass(frozen=True)
class PRXDCAE(PRXBase):
in_channels: int = 32
patch_size: int = 1
def build_config(vae_type: str) -> Tuple[dict, int]:
if vae_type == "flux":
cfg = PRXFlux()
elif vae_type == "dc-ae":
cfg = PRXDCAE()
else:
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
config_dict = asdict(cfg)
config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
return config_dict
def create_parameter_mapping(depth: int) -> dict:
"""Create mapping from old parameter names to new diffusers names."""
# Key mappings for structural changes
mapping = {}
# Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
for i in range(depth):
# QKV projections moved to attention module
mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
# QK norm moved to attention module and renamed to match Attention's qk_norm structure
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
# K norm for text tokens moved to attention module
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
# Attention output projection
mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
return mapping
def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
"""Convert old checkpoint parameters to new diffusers format."""
print("Converting checkpoint parameters...")
mapping = create_parameter_mapping(depth)
converted_state_dict = {}
for key, value in old_state_dict.items():
new_key = key
# Apply specific mappings if needed
if key in mapping:
new_key = mapping[key]
print(f" Mapped: {key} -> {new_key}")
converted_state_dict[new_key] = value
print(f"✓ Converted {len(converted_state_dict)} parameters")
return converted_state_dict
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
"""Create and load PRXTransformer2DModel from old checkpoint."""
print(f"Loading checkpoint from: {checkpoint_path}")
# Load old checkpoint
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Handle different checkpoint formats
if isinstance(old_checkpoint, dict):
if "model" in old_checkpoint:
state_dict = old_checkpoint["model"]
elif "state_dict" in old_checkpoint:
state_dict = old_checkpoint["state_dict"]
else:
state_dict = old_checkpoint
else:
state_dict = old_checkpoint
print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
# Convert parameter names if needed
model_depth = int(config.get("depth", 16))
converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
# Create transformer with config
print("Creating PRXTransformer2DModel...")
transformer = PRXTransformer2DModel(**config)
# Load state dict
print("Loading converted parameters...")
missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
if missing_keys:
print(f"⚠ Missing keys: {missing_keys}")
if unexpected_keys:
print(f"⚠ Unexpected keys: {unexpected_keys}")
if not missing_keys and not unexpected_keys:
print("✓ All parameters loaded successfully!")
return transformer
def create_scheduler_config(output_path: str, shift: float):
"""Create FlowMatchEulerDiscreteScheduler config."""
scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
scheduler_path = os.path.join(output_path, "scheduler")
os.makedirs(scheduler_path, exist_ok=True)
with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
json.dump(scheduler_config, f, indent=2)
print("✓ Created scheduler config")
def download_and_save_vae(vae_type: str, output_path: str):
"""Download and save VAE to local directory."""
from diffusers import AutoencoderDC, AutoencoderKL
vae_path = os.path.join(output_path, "vae")
os.makedirs(vae_path, exist_ok=True)
if vae_type == "flux":
print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
else: # dc-ae
print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
vae.save_pretrained(vae_path)
print(f"✓ Saved VAE to {vae_path}")
def download_and_save_text_encoder(output_path: str):
"""Download and save T5Gemma text encoder and tokenizer."""
from transformers import GemmaTokenizerFast
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
text_encoder_path = os.path.join(output_path, "text_encoder")
tokenizer_path = os.path.join(output_path, "tokenizer")
os.makedirs(text_encoder_path, exist_ok=True)
os.makedirs(tokenizer_path, exist_ok=True)
print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
# Extract and save only the encoder
t5gemma_encoder = t5gemma_model.encoder
t5gemma_encoder.save_pretrained(text_encoder_path)
print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
tokenizer.save_pretrained(tokenizer_path)
print(f"✓ Saved tokenizer to {tokenizer_path}")
def create_model_index(vae_type: str, default_image_size: int, output_path: str):
"""Create model_index.json for the pipeline."""
if vae_type == "flux":
vae_class = "AutoencoderKL"
else: # dc-ae
vae_class = "AutoencoderDC"
model_index = {
"_class_name": "PRXPipeline",
"_diffusers_version": "0.31.0.dev0",
"_name_or_path": os.path.basename(output_path),
"default_sample_size": default_image_size,
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
"text_encoder": ["prx", "T5GemmaEncoder"],
"tokenizer": ["transformers", "GemmaTokenizerFast"],
"transformer": ["diffusers", "PRXTransformer2DModel"],
"vae": ["diffusers", vae_class],
}
model_index_path = os.path.join(output_path, "model_index.json")
with open(model_index_path, "w") as f:
json.dump(model_index, f, indent=2)
def main(args):
# Validate inputs
if not os.path.exists(args.checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
config = build_config(args.vae_type)
# Create output directory
os.makedirs(args.output_path, exist_ok=True)
print(f"✓ Output directory: {args.output_path}")
# Create transformer from checkpoint
transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
# Save transformer
transformer_path = os.path.join(args.output_path, "transformer")
os.makedirs(transformer_path, exist_ok=True)
# Save config
with open(os.path.join(transformer_path, "config.json"), "w") as f:
json.dump(config, f, indent=2)
# Save model weights as safetensors
state_dict = transformer.state_dict()
save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
print(f"✓ Saved transformer to {transformer_path}")
# Create scheduler config
create_scheduler_config(args.output_path, args.shift)
download_and_save_vae(args.vae_type, args.output_path)
download_and_save_text_encoder(args.output_path)
# Create model_index.json
create_model_index(args.vae_type, args.resolution, args.output_path)
# Verify the pipeline can be loaded
try:
pipeline = PRXPipeline.from_pretrained(args.output_path)
print("Pipeline loaded successfully!")
print(f"Transformer: {type(pipeline.transformer).__name__}")
print(f"VAE: {type(pipeline.vae).__name__}")
print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
print(f"Scheduler: {type(pipeline.scheduler).__name__}")
# Display model info
num_params = sum(p.numel() for p in pipeline.transformer.parameters())
print(f"✓ Transformer parameters: {num_params:,}")
except Exception as e:
print(f"Pipeline verification failed: {e}")
return False
print("Conversion completed successfully!")
print(f"Converted pipeline saved to: {args.output_path}")
print(f"VAE type: {args.vae_type}")
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
)
parser.add_argument(
"--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
)
parser.add_argument(
"--vae_type",
type=str,
choices=["flux", "dc-ae"],
required=True,
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
)
parser.add_argument(
"--resolution",
type=int,
choices=[256, 512, 1024],
default=DEFAULT_RESOLUTION,
help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
)
parser.add_argument(
"--shift",
type=float,
default=3.0,
help="Shift for the scheduler",
)
args = parser.parse_args()
try:
success = main(args)
if not success:
sys.exit(1)
except Exception as e:
print(f"Conversion failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
@@ -10,7 +10,7 @@ from accelerate import init_empty_weights
from diffusers import (
SanaControlNetModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
+1 -1
View File
@@ -20,7 +20,7 @@ from diffusers import (
SanaTransformer2DModel,
SCMScheduler,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
+327
View File
@@ -0,0 +1,327 @@
#!/usr/bin/env python
from __future__ import annotations
import argparse
import os
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
AutoencoderKLWan,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaVideoPipeline,
SanaVideoTransformer3DModel,
UniPCMultistepScheduler,
)
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
def main(args):
cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
snapshot_download(
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
cache_dir=cache_dir_path,
repo_type="model",
)
file_path = hf_hub_download(
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
cache_dir=cache_dir_path,
repo_type="model",
)
else:
file_path = args.orig_ckpt_path
print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
all_state_dict = torch.load(file_path, weights_only=True)
state_dict = all_state_dict.pop("state_dict")
converted_state_dict = {}
# Patch embeddings.
converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight")
converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias")
# Caption projection.
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Shared norm.
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
# y norm
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# scheduler
flow_shift = 8.0
if args.task == "i2v":
assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
# model config
layer_num = 20
# Positional embedding interpolation scale.
qk_norm = True
# sample size
if args.video_size == 480:
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
patch_size = (1, 2, 2)
elif args.video_size == 720:
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
patch_size = (1, 1, 1)
else:
raise ValueError(f"Video size {args.video_size} is not supported.")
for depth in range(layer_num):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
f"blocks.{depth}.scale_shift_table"
)
# Linear Attention is all you need 🤘
# Self attention.
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
if qk_norm is not None:
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.attn.k_norm.weight"
)
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.attn.proj.bias"
)
# Feed-forward.
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.inverted_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.inverted_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.depth_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.depth_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.point_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_temp.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.t_conv.weight"
)
# Cross-attention.
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
if qk_norm is not None:
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.k_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.bias"
)
# Final block.
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
# Transformer
with CTX():
transformer_kwargs = {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 20,
"attention_head_dim": 112,
"num_layers": 20,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"caption_channels": 2304,
"mlp_ratio": 3.0,
"attention_bias": False,
"sample_size": sample_size,
"patch_size": patch_size,
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 1024,
}
transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
try:
state_dict.pop("y_embedder.y_embedding")
state_dict.pop("pos_embed")
state_dict.pop("logvar_linear.weight")
state_dict.pop("logvar_linear.bias")
except KeyError:
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")
transformer = transformer.to(weight_dtype)
if not args.save_full_pipeline:
print(
colored(
f"Only saving transformer model of {args.model_type}. "
f"Set --save_full_pipeline to save the whole Pipeline",
"green",
attrs=["bold"],
)
)
transformer.save_pretrained(
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
)
else:
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
# Text Encoder
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
tokenizer.padding_side = "right"
text_encoder = AutoModelForCausalLM.from_pretrained(
text_encoder_model_path, torch_dtype=torch.bfloat16
).get_decoder()
# Choose the appropriate pipeline and scheduler based on model type
# Original Sana scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
elif args.scheduler_type == "uni-pc":
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction",
use_flow_sigmas=True,
num_train_timesteps=1000,
flow_shift=flow_shift,
)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
pipe = SanaVideoPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=vae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--video_size",
default=480,
type=int,
choices=[480, 720],
required=False,
help="Video size of pretrained model, 480 or 720.",
)
parser.add_argument(
"--model_type",
default="SanaVideo",
type=str,
choices=[
"SanaVideo",
],
)
parser.add_argument(
"--scheduler_type",
default="flow-dpm_solver",
type=str,
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = DTYPE_MAPPING[args.dtype]
main(args)
+1 -1
View File
@@ -7,7 +7,7 @@ from accelerate import init_empty_weights
from diffusers import AutoencoderKL, SD3Transformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
+1 -1
View File
@@ -18,7 +18,7 @@ from diffusers import (
StableAudioPipeline,
StableAudioProjectionModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils import is_accelerate_available
+1 -1
View File
@@ -20,7 +20,7 @@ from diffusers import (
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
+1 -1
View File
@@ -20,7 +20,7 @@ from diffusers import (
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
+265 -6
View File
@@ -6,11 +6,20 @@ import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
from transformers import (
AutoProcessor,
AutoTokenizer,
CLIPImageProcessor,
CLIPVisionModel,
CLIPVisionModelWithProjection,
UMT5EncoderModel,
)
from diffusers import (
AutoencoderKLWan,
UniPCMultistepScheduler,
WanAnimatePipeline,
WanAnimateTransformer3DModel,
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
@@ -105,8 +114,203 @@ VACE_TRANSFORMER_KEYS_RENAME_DICT = {
"after_proj": "proj_out",
}
ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"cross_attn.k_img": "attn2.to_k_img",
"cross_attn.v_img": "attn2.to_v_img",
"cross_attn.norm_k_img": "attn2.norm_k_img",
# After cross_attn -> attn2 rename, we need to rename the img keys
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
# Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
# Motion encoder mappings
# The name mapping is complicated for the convolutional part so we handle that in its own function
"motion_encoder.enc.fc": "motion_encoder.motion_network",
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
# Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
"face_encoder.conv1_local.conv": "face_encoder.conv1_local",
"face_encoder.conv2.conv": "face_encoder.conv2",
"face_encoder.conv3.conv": "face_encoder.conv3",
# Face adapter mappings are handled in a separate function
}
# TODO: Verify this and simplify if possible.
def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
"""
Convert all motion encoder weights for Animate model.
In the original model:
- All Linear layers in fc use EqualLinear
- All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
- Blur kernels are stored as buffers in Sequential modules
- ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]
Conversion strategy:
1. Drop .kernel buffers (blur kernels)
2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
"""
# Skip if not a weight, bias, or kernel
if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
return
# Handle Blur kernel buffers from original implementation.
# After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
# Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
if ".kernel" in key and "motion_encoder" in key:
# Remove unexpected blur kernel buffers to avoid strict load errors
state_dict.pop(key, None)
return
# Rename Sequential indices to named components in ConvLayer and ResBlock
if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
parts = key.split(".")
# Find the sequential index (digit) after convs or after conv1/conv2/skip
# Examples:
# - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
# - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
# - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
# - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
# - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
# - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
# - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
# - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
# - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
# - enc.net_app.convs.8 -> conv_out (final conv layer)
convs_idx = parts.index("convs") if "convs" in parts else -1
if convs_idx >= 0 and len(parts) - convs_idx >= 2:
bias = False
# The nn.Sequential index will always follow convs
sequential_idx = int(parts[convs_idx + 1])
if sequential_idx == 0:
if key.endswith(".weight"):
new_key = "motion_encoder.conv_in.weight"
elif key.endswith(".bias"):
new_key = "motion_encoder.conv_in.act_fn.bias"
bias = True
elif sequential_idx == final_conv_idx:
if key.endswith(".weight"):
new_key = "motion_encoder.conv_out.weight"
else:
# Intermediate .convs. layers, which get mapped to .res_blocks.
prefix = "motion_encoder.res_blocks."
layer_name = parts[convs_idx + 2]
if layer_name == "skip":
layer_name = "conv_skip"
if key.endswith(".weight"):
param_name = "weight"
elif key.endswith(".bias"):
param_name = "act_fn.bias"
bias = True
suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
suffix = ".".join(suffix_parts)
new_key = prefix + suffix
param = state_dict.pop(key)
if bias:
param = param.squeeze()
state_dict[new_key] = param
return
return
return
def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
"""
Convert face adapter weights for the Animate model.
The original model uses a fused KV projection but the diffusers models uses separate K and V projections.
"""
# Skip if not a weight or bias
if ".weight" not in key and ".bias" not in key:
return
prefix = "face_adapter."
if ".fuser_blocks." in key:
parts = key.split(".")
module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
block_idx = parts[module_list_idx + 1]
layer_name = parts[module_list_idx + 2]
param_name = parts[module_list_idx + 3]
if layer_name == "linear1_kv":
layer_name_k = "to_k"
layer_name_v = "to_v"
suffix_k = ".".join([block_idx, layer_name_k, param_name])
suffix_v = ".".join([block_idx, layer_name_v, param_name])
new_key_k = prefix + suffix_k
new_key_v = prefix + suffix_v
kv_proj = state_dict.pop(key)
k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
state_dict[new_key_k] = k_proj
state_dict[new_key_v] = v_proj
return
else:
if layer_name == "q_norm":
new_layer_name = "norm_q"
elif layer_name == "k_norm":
new_layer_name = "norm_k"
elif layer_name == "linear1_q":
new_layer_name = "to_q"
elif layer_name == "linear2":
new_layer_name = "to_out"
suffix_parts = [block_idx, new_layer_name, param_name]
suffix = ".".join(suffix_parts)
new_key = prefix + suffix
state_dict[new_key] = state_dict.pop(key)
return
return
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"motion_encoder": convert_animate_motion_encoder_weights,
"face_adapter": convert_animate_face_adapter_weights,
}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -364,6 +568,37 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-Animate-14B":
config = {
"model_id": "Wan-AI/Wan2.2-Animate-14B",
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": (1, 2, 2),
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"rope_max_seq_len": 1024,
"pos_embed_seq_len": None,
"motion_encoder_size": 512, # Start of Wan Animate-specific configs
"motion_style_dim": 512,
"motion_dim": 20,
"motion_encoder_dim": 512,
"face_encoder_hidden_dim": 1024,
"face_encoder_num_heads": 4,
"inject_face_latents_blocks": 5,
},
}
RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
@@ -380,10 +615,12 @@ def convert_transformer(model_type: str, stage: str = None):
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
if "VACE" not in model_type:
transformer = WanTransformer3DModel.from_config(diffusers_config)
else:
if "Animate" in model_type:
transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
elif "VACE" in model_type:
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
else:
transformer = WanTransformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -397,7 +634,12 @@ def convert_transformer(model_type: str, stage: str = None):
continue
handler_fn_inplace(key, original_state_dict)
# Load state dict into the meta model, which will materialize the tensors
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
# Move to CPU to ensure all tensors are materialized
transformer = transformer.to("cpu")
return transformer
@@ -926,7 +1168,7 @@ DTYPE_MAPPING = {
if __name__ == "__main__":
args = get_args()
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
transformer = convert_transformer(args.model_type, stage="high_noise_model")
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
else:
@@ -942,7 +1184,7 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
if "FLF2V" in args.model_type:
flow_shift = 16.0
elif "TI2V" in args.model_type:
elif "TI2V" in args.model_type or "Animate" in args.model_type:
flow_shift = 5.0
else:
flow_shift = 3.0
@@ -954,6 +1196,8 @@ if __name__ == "__main__":
if args.dtype != "none":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)
if transformer_2 is not None:
transformer_2.to(dtype)
if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
pipe = WanImageToVideoPipeline(
@@ -1016,6 +1260,21 @@ if __name__ == "__main__":
vae=vae,
scheduler=scheduler,
)
elif "Animate" in args.model_type:
image_encoder = CLIPVisionModel.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
pipe = WanAnimatePipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
image_encoder=image_encoder,
image_processor=image_processor,
)
else:
pipe = WanPipeline(
transformer=transformer,
+39
View File
@@ -149,7 +149,9 @@ else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"AdaptiveProjectedMixGuidance",
"AutoGuidance",
"BaseGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"FrequencyDecoupledGuidance",
@@ -184,6 +186,8 @@ else:
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLCosmos",
"AutoencoderKLHunyuanImage",
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
@@ -194,9 +198,11 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"AutoModel",
"BriaFiboTransformer2DModel",
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
"ChronoEditTransformer3DModel",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"CogView4Transformer2DModel",
@@ -216,6 +222,7 @@ else:
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanImageTransformer2DModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
@@ -234,11 +241,13 @@ else:
"ParallelConfig",
"PixArtTransformer2DModel",
"PriorTransformer",
"PRXTransformer2DModel",
"QwenImageControlNetModel",
"QwenImageMultiControlNetModel",
"QwenImageTransformer2DModel",
"SanaControlNetModel",
"SanaTransformer2DModel",
"SanaVideoTransformer3DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
@@ -259,6 +268,7 @@ else:
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
"WanAnimateTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"attention_backend",
@@ -398,6 +408,7 @@ else:
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanModularPipeline",
]
@@ -424,9 +435,11 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
"ChronoEditPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
"CogVideoXImageToVideoPipeline",
@@ -461,6 +474,8 @@ else:
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanImagePipeline",
"HunyuanImageRefinerPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
"HunyuanVideoImageToVideoPipeline",
@@ -519,6 +534,7 @@ else:
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"PRXPipeline",
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
"QwenImageEditInpaintPipeline",
@@ -529,10 +545,13 @@ else:
"QwenImagePipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
"SanaImageToVideoPipeline",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintImg2ImgPipeline",
"SanaSprintPipeline",
"SanaVideoPipeline",
"SanaVideoPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -620,6 +639,7 @@ else:
"VisualClozeGenerationPipeline",
"VisualClozePipeline",
"VQDiffusionPipeline",
"WanAnimatePipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVACEPipeline",
@@ -847,7 +867,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .guiders import (
AdaptiveProjectedGuidance,
AdaptiveProjectedMixGuidance,
AutoGuidance,
BaseGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
@@ -878,6 +900,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -888,9 +912,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
AutoModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
ChronoEditTransformer3DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
@@ -910,6 +936,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
@@ -928,11 +955,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ParallelConfig,
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
QwenImageControlNetModel,
QwenImageMultiControlNetModel,
QwenImageTransformer2DModel,
SanaControlNetModel,
SanaTransformer2DModel,
SanaVideoTransformer3DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
@@ -952,6 +981,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
attention_backend,
@@ -1066,6 +1096,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
Wan22AutoBlocks,
WanAutoBlocks,
WanModularPipeline,
)
@@ -1088,9 +1119,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
ChronoEditPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
@@ -1125,6 +1158,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
HunyuanImagePipeline,
HunyuanImageRefinerPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,
@@ -1183,6 +1218,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
PRXPipeline,
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
@@ -1193,10 +1229,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImagePipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
SanaImageToVideoPipeline,
SanaPAGPipeline,
SanaPipeline,
SanaSprintImg2ImgPipeline,
SanaSprintPipeline,
SanaVideoPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
@@ -1283,6 +1321,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VisualClozeGenerationPipeline,
VisualClozePipeline,
VQDiffusionPipeline,
WanAnimatePipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVACEPipeline,
+3 -13
View File
@@ -14,28 +14,18 @@
from typing import Union
from ..utils import is_torch_available
from ..utils import is_torch_available, logging
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .guider_utils import BaseGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
GuiderType = Union[
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
]
@@ -65,8 +65,9 @@ class AdaptiveProjectedGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
@@ -76,19 +77,27 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
@@ -152,6 +161,44 @@ class MomentumBuffer:
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def __repr__(self) -> str:
"""
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
"""
if isinstance(self.running_average, torch.Tensor):
shape = tuple(self.running_average.shape)
# Calculate statistics
with torch.no_grad():
stats = {
"mean": self.running_average.mean().item(),
"std": self.running_average.std().item(),
"min": self.running_average.min().item(),
"max": self.running_average.max().item(),
}
# Get a slice (max 3 elements per dimension)
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
sliced_data = self.running_average[slice_indices]
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
slice_str = str(sliced_data.detach().float().cpu().numpy())
if len(slice_str) > 200: # Truncate if too long
slice_str = slice_str[:200] + "..."
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
return (
f"MomentumBuffer(\n"
f" momentum={self.momentum},\n"
f" shape={shape},\n"
f" stats=[{stats_str}],\n"
f" slice={slice_str}\n"
f")"
)
else:
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
def normalized_guidance(
pred_cond: torch.Tensor,
@@ -0,0 +1,297 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AdaptiveProjectedMixGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
(CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance stops.
adaptive_projected_guidance_start_step (`int`, defaults to `5`):
The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
used, and momentum buffer is updated).
enabled (`bool`, defaults to `True`):
Whether this guidance is enabled.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 3.5,
guidance_rescale: float = 0.0,
adaptive_projected_guidance_scale: float = 10.0,
adaptive_projected_guidance_momentum: float = -0.5,
adaptive_projected_guidance_rescale: float = 10.0,
eta: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
adaptive_projected_guidance_start_step: int = 5,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
# no guidance
if not self._is_cfg_enabled():
pred = pred_cond
# CFG + update momentum buffer
elif not self._is_apg_enabled():
if self.momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
# CFG + update momentum buffer
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
# APG
elif self._is_apg_enabled():
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.adaptive_projected_guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_apg_enabled() or self._is_cfg_enabled():
num_conditions += 1
return num_conditions
# Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
if not self._is_cfg_enabled():
return False
is_within_range = False
if self._step is not None:
is_within_range = self._step > self.adaptive_projected_guidance_start_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
else:
is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
return is_within_range and not is_close
def get_state(self):
state = super().get_state()
state["momentum_buffer"] = self.momentum_buffer
state["is_apg_enabled"] = self._is_apg_enabled()
state["is_cfg_enabled"] = self._is_cfg_enabled()
return state
# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def __repr__(self) -> str:
"""
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
"""
if isinstance(self.running_average, torch.Tensor):
shape = tuple(self.running_average.shape)
# Calculate statistics
with torch.no_grad():
stats = {
"mean": self.running_average.mean().item(),
"std": self.running_average.std().item(),
"min": self.running_average.min().item(),
"max": self.running_average.max().item(),
}
# Get a slice (max 3 elements per dimension)
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
sliced_data = self.running_average[slice_indices]
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
slice_str = str(sliced_data.detach().float().cpu().numpy())
if len(slice_str) > 200: # Truncate if too long
slice_str = slice_str[:200] + "..."
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
return (
f"MomentumBuffer(\n"
f" momentum={self.momentum},\n"
f" shape={shape},\n"
f" stats=[{stats_str}],\n"
f" slice={slice_str}\n"
f")"
)
else:
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
def update_momentum_buffer(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
momentum_buffer: Optional[MomentumBuffer] = None,
):
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
if momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
diff = momentum_buffer.running_average
else:
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
+15 -9
View File
@@ -72,8 +72,9 @@ class AutoGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.auto_guidance_layers = auto_guidance_layers
@@ -132,16 +133,21 @@ class AutoGuidance(BaseGuidance):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
@@ -27,43 +27,50 @@ if TYPE_CHECKING:
class ClassifierFreeGuidance(BaseGuidance):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
Implements Classifier-Free Guidance (CFG) for diffusion models.
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
proposes scaling and shifting the conditional distribution based on the difference between conditional and
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Reference: https://huggingface.co/papers/2207.12598
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
CFG improves generation quality and prompt adherence by jointly training models on both conditional and
unconditional data, then combining predictions during inference. This allows trading off between quality (high
guidance) and diversity (low guidance).
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
**Two CFG Formulations:**
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
1. **Original formulation** (from paper):
```
x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
```
Moves conditional predictions further from unconditional ones.
2. **Diffusers-native formulation** (default, from Imagen paper):
```
x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
```
Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
quality", "watermarks"). Equivalent in theory but more intuitive.
Use `use_original_formulation=True` to switch to the original formulation.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
may reduce quality. Typical range: 1.0-20.0.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
to 1.0 (full rescaling).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
diffusers-native formulation from the Imagen paper.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
steps.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
steps.
enabled (`bool`, defaults to `True`):
Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@@ -76,23 +83,29 @@ class ClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
@@ -68,31 +68,41 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
if self._step < self.zero_init_steps:
# YiYi Notes: add default behavior for self._enabled == False
if not self._enabled:
pred = pred_cond
elif self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled():
pred = pred_cond
@@ -149,6 +149,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
stop: Union[float, List[float], Tuple[float]] = 1.0,
guidance_rescale_space: str = "data",
upcast_to_double: bool = True,
enabled: bool = True,
):
if not _CAN_USE_KORNIA:
raise ImportError(
@@ -160,7 +161,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
# Set start to earliest start for any freq component and stop to latest stop for any freq component
min_start = start if isinstance(start, float) else min(start)
max_stop = stop if isinstance(stop, float) else max(stop)
super().__init__(min_start, max_stop)
super().__init__(min_start, max_stop, enabled)
self.guidance_scales = guidance_scales
self.levels = len(guidance_scales)
@@ -217,16 +218,21 @@ class FrequencyDecoupledGuidance(BaseGuidance):
f"({len(self.guidance_scales)})"
)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
+124 -41
View File
@@ -40,7 +40,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
_input_predictions = None
_identifier_key = "__guidance_identifier__"
def __init__(self, start: float = 0.0, stop: float = 1.0):
def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
logger.warning(
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
)
self._start = start
self._stop = stop
self._step: int = None
@@ -48,7 +52,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
self._enabled = True
self._enabled = enabled
if not (0.0 <= start < 1.0):
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
@@ -60,6 +64,31 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def new(self, **kwargs):
"""
Creates a copy of this guider instance, optionally with modified configuration parameters.
Args:
**kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
returns an exact copy with the same configuration.
Returns:
A new guider instance with the same (or updated) configuration.
Example:
```python
# Create a CFG guider
guider = ClassifierFreeGuidance(guidance_scale=3.5)
# Create an exact copy
same_guider = guider.new()
# Create a copy with different start step, keeping other config the same
new_guider = guider.new(guidance_scale=5)
```
"""
return self.__class__.from_config(self.config, **kwargs)
def disable(self):
self._enabled = False
@@ -72,42 +101,52 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep = timestep
self._count_prepared = 0
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
def get_state(self) -> Dict[str, Any]:
"""
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
the values of the provided keyword arguments to this method.
Args:
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with a
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
conditional data identifier and the second element must be the unconditional data identifier or None.
Example:
```
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
BaseGuidance.set_input_fields(
latents="latents",
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)
```
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
the __repr__ method. Returns:
`Dict[str, Any]`: A dictionary containing the current state variables including:
- step: Current inference step
- num_inference_steps: Total number of inference steps
- timestep: Current timestep tensor
- count_prepared: Number of times prepare_models has been called
- enabled: Whether the guidance is enabled
- num_conditions: Number of conditions
"""
for key, value in kwargs.items():
is_string = isinstance(value, str)
is_tuple_of_str_with_len_2 = (
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
)
if not (is_string or is_tuple_of_str_with_len_2):
raise ValueError(
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
state = {
"step": self._step,
"num_inference_steps": self._num_inference_steps,
"timestep": self._timestep,
"count_prepared": self._count_prepared,
"enabled": self._enabled,
"num_conditions": self.num_conditions,
}
return state
def __repr__(self) -> str:
"""
Returns a string representation of the guidance object including both config and current state.
"""
# Get ConfigMixin's __repr__
str_repr = super().__repr__()
# Get current state
state = self.get_state()
# Format each state variable on its own line with indentation
state_lines = []
for k, v in state.items():
# Convert value to string and handle multi-line values
v_str = str(v)
if "\n" in v_str:
# For multi-line values (like MomentumBuffer), indent subsequent lines
v_lines = v_str.split("\n")
v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
state_lines.append(f" {k}: {v_str}")
state_str = "\n".join(state_lines)
return f"{str_repr}\nState:\n{state_str}"
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
@@ -127,6 +166,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
def __call__(self, data: List["BlockState"]) -> Any:
if not all(hasattr(d, "noise_pred") for d in data):
raise ValueError("Expected all data to have `noise_pred` attribute.")
@@ -154,6 +198,49 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
@classmethod
def _prepare_batch(
cls,
data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
tuple_index: int,
identifier: str,
) -> "BlockState":
"""
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
Args:
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation. If a string is provided, it will be used as the
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
length 2 is provided, the first element must be the conditional data identifier and the second element
must be the unconditional data identifier or None.
data (`BlockState`):
The input data to be prepared.
tuple_index (`int`):
The index to use when accessing input fields that are tuples.
Returns:
`BlockState`: The prepared batch of data.
"""
from ..modular_pipelines.modular_pipeline import BlockState
data_batch = {}
for key, value in data.items():
try:
if isinstance(value, torch.Tensor):
data_batch[key] = value
elif isinstance(value, tuple):
data_batch[key] = value[tuple_index]
else:
raise ValueError(f"Invalid value type: {type(value)}")
except ValueError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
@classmethod
def _prepare_batch_from_block_state(
cls,
input_fields: Dict[str, Union[str, Tuple[str, str]]],
data: "BlockState",
@@ -182,10 +269,6 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"""
from ..modular_pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError(
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
)
data_batch = {}
for key, value in input_fields.items():
try:
@@ -290,7 +373,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -98,8 +98,9 @@ class PerturbedAttentionGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = perturbed_guidance_scale
@@ -168,12 +169,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
registry.remove_hook(hook_name, recurse=True)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -186,8 +182,28 @@ class PerturbedAttentionGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
+25 -9
View File
@@ -100,8 +100,9 @@ class SkipLayerGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
@@ -164,12 +165,7 @@ class SkipLayerGuidance(BaseGuidance):
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -182,8 +178,28 @@ class SkipLayerGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
@@ -92,8 +92,9 @@ class SmoothedEnergyGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale
@@ -153,12 +154,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -171,8 +167,28 @@ class SmoothedEnergyGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
@@ -58,23 +58,29 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
+30
View File
@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
@@ -149,6 +150,14 @@ def _register_attention_processors_metadata():
),
)
# HunyuanImageAttnProcessor
AttentionProcessorRegistry.register(
model_class=HunyuanImageAttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -162,6 +171,10 @@ def _register_transformer_blocks_metadata():
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_hunyuanimage import (
HunyuanImageSingleTransformerBlock,
HunyuanImageTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
@@ -283,6 +296,22 @@ def _register_transformer_blocks_metadata():
),
)
# HunyuanImage2.1
TransformerBlockRegistry.register(
model_class=HunyuanImageTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanImageSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
@@ -308,4 +337,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on
+5 -3
View File
@@ -203,10 +203,12 @@ class ContextParallelSplitHook(ModelHook):
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
raise ValueError(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
logger.warning_once(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
return x
else:
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
class ContextParallelGatherHook(ModelHook):
+34 -11
View File
@@ -409,7 +409,7 @@ class VaeImageProcessor(ConfigMixin):
src_w = width if ratio < src_ratio else image.width * height // image.height
src_h = height if ratio >= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
@@ -460,7 +460,7 @@ class VaeImageProcessor(ConfigMixin):
src_w = width if ratio > src_ratio else image.width * height // image.height
src_h = height if ratio <= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
return res
@@ -1045,16 +1045,39 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
Convert an RGB-like depth image to a depth map.
Args:
image (`Union[np.ndarray, torch.Tensor]`):
The RGB-like depth image to convert.
Returns:
`Union[np.ndarray, torch.Tensor]`:
The corresponding depth map.
"""
return image[:, :, 1] * 2**8 + image[:, :, 2]
# 1. Cast the tensor to a larger integer type (e.g., int32)
# to safely perform the multiplication by 256.
# 2. Perform the 16-bit combination: High-byte * 256 + Low-byte.
# 3. Cast the final result to the desired depth map type (uint16) if needed
# before returning, though leaving it as int32/int64 is often safer
# for return value from a library function.
if isinstance(image, torch.Tensor):
# Cast to a safe dtype (e.g., int32 or int64) for the calculation
original_dtype = image.dtype
image_safe = image.to(torch.int32)
# Calculate the depth map
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
# You may want to cast the final result to uint16, but casting to a
# larger int type (like int32) is sufficient to fix the overflow.
# depth_map = depth_map.to(torch.uint16) # Uncomment if uint16 is strictly required
return depth_map.to(original_dtype)
elif isinstance(image, np.ndarray):
# NumPy equivalent: Cast to a safe dtype (e.g., np.int32)
original_dtype = image.dtype
image_safe = image.astype(np.int32)
# Calculate the depth map
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
# depth_map = depth_map.astype(np.uint16) # Uncomment if uint16 is strictly required
return depth_map.astype(original_dtype)
else:
raise TypeError("Input image must be a torch.Tensor or np.ndarray")
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
r"""
+29 -5
View File
@@ -1977,14 +1977,34 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
"time_projection.1.diff_b"
)
if any("head.head" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
if any("head.head" in k for k in original_state_dict):
if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
f"head.head.{lora_up_key}.weight"
)
if "head.head.diff_b" in original_state_dict:
converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
# Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
# This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
# Since for this particular LoRA, we don't have the corresponding up matrix, I will use
# an identity.
if any("head.head" in k and k.endswith(".diff") for k in state_dict):
if f"head.head.{lora_down_key}.weight" in state_dict:
logger.info(
f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
)
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
*up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
).T
for text_time in ["text_embedding", "time_embedding"]:
if any(text_time in k for k in original_state_dict):
for b_n in [0, 2]:
@@ -2193,6 +2213,10 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
state_dict = {convert_key(k): v for k, v in state_dict.items()}
has_default = any("default." in k for k in state_dict)
if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
+2 -1
View File
@@ -4940,7 +4940,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
has_default = any("default." in k for k in state_dict)
if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
out = (state_dict, metadata) if return_lora_metadata else state_dict
+16
View File
@@ -36,6 +36,8 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
@@ -82,7 +84,9 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
@@ -91,16 +95,20 @@ if is_torch_available():
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
_import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
@@ -132,6 +140,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -168,8 +178,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
ChromaTransformer2DModel,
ChronoEditTransformer3DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
@@ -181,6 +193,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,
@@ -192,14 +205,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
QwenImageTransformer2DModel,
SanaTransformer2DModel,
SanaVideoTransformer3DModel,
SD3Transformer2DModel,
SkyReelsV2Transformer3DModel,
StableAudioDiTModel,
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
+47 -25
View File
@@ -44,11 +44,16 @@ class ContextParallelConfig:
Args:
ring_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
total number of devices in the context parallel mesh.
Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
ulysses_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
total number of devices in the context parallel mesh.
Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes
local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
good interconnect bandwidth.
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
Whether to convert output and LSE to float32 for ring attention numerical stability.
rotate_method (`str`, *optional*, defaults to `"allgather"`):
@@ -79,29 +84,46 @@ class ContextParallelConfig:
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.ring_degree == 1 and self.ulysses_degree == 1:
raise ValueError(
"Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
@property
def mesh_shape(self) -> Tuple[int, int]:
return (self.ring_degree, self.ulysses_degree)
@property
def mesh_dim_names(self) -> Tuple[str, str]:
"""Dimension names for the device mesh."""
return ("ring", "ulysses")
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
self._rank = rank
self._world_size = world_size
self._device = device
self._mesh = mesh
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
if self.ulysses_degree * self.ring_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
if self._flattened_mesh is None:
self._flattened_mesh = self._mesh._flatten()
if self._ring_mesh is None:
self._ring_mesh = self._mesh["ring"]
if self._ulysses_mesh is None:
self._ulysses_mesh = self._mesh["ulysses"]
if self._ring_local_rank is None:
self._ring_local_rank = self._ring_mesh.get_local_rank()
if self._ulysses_local_rank is None:
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
@dataclass
@@ -119,7 +141,7 @@ class ParallelConfig:
_rank: int = None
_world_size: int = None
_device: torch.device = None
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
_mesh: torch.distributed.device_mesh.DeviceMesh = None
def setup(
self,
@@ -127,14 +149,14 @@ class ParallelConfig:
world_size: int,
device: torch.device,
*,
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
):
self._rank = rank
self._world_size = world_size
self._device = device
self._cp_mesh = cp_mesh
self._mesh = mesh
if self.context_parallel_config is not None:
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
self.context_parallel_config.setup(rank, world_size, device, mesh)
@dataclass(frozen=True)
+234 -52
View File
@@ -16,6 +16,7 @@ import contextlib
import functools
import inspect
import math
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
@@ -27,6 +28,8 @@ if torch.distributed.is_available():
from ..utils import (
get_logger,
is_aiter_available,
is_aiter_version,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
@@ -40,13 +43,14 @@ from ..utils import (
is_xformers_available,
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
if TYPE_CHECKING:
from ._modeling_parallel import ParallelConfig
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
@@ -54,6 +58,7 @@ _REQUIRED_XFORMERS_VERSION = "0.0.29"
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
@@ -78,17 +83,10 @@ else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if DIFFUSERS_ENABLE_HUB_KERNELS:
if not is_kernels_available():
raise ImportError(
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
)
from ..utils.kernels_utils import _get_fa3_from_hub
flash_attn_interface_hub = _get_fa3_from_hub()
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
else:
flash_attn_3_func_hub = None
aiter_flash_attn_func = None
if _CAN_USE_SAGE_ATTN:
from sageattention import (
@@ -178,6 +176,9 @@ class AttentionBackendName(str, Enum):
_FLASH_3_HUB = "_flash_3_hub"
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
# `aiter`
AITER = "aiter"
# PyTorch native
FLEX = "flex"
NATIVE = "native"
@@ -207,7 +208,7 @@ class _AttentionBackendRegistry:
_backends = {}
_constraints = {}
_supported_arg_names = {}
_supports_context_parallel = {}
_supports_context_parallel = set()
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
_checks_enabled = DIFFUSERS_ATTN_CHECKS
@@ -224,7 +225,9 @@ class _AttentionBackendRegistry:
cls._backends[backend] = func
cls._constraints[backend] = constraints or []
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
cls._supports_context_parallel[backend] = supports_context_parallel
if supports_context_parallel:
cls._supports_context_parallel.add(backend.value)
return func
return decorator
@@ -238,15 +241,31 @@ class _AttentionBackendRegistry:
return list(cls._backends.keys())
@classmethod
def _is_context_parallel_enabled(
cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
def _is_context_parallel_available(
cls,
backend: AttentionBackendName,
) -> bool:
supports_context_parallel = backend in cls._supports_context_parallel
is_degree_greater_than_1 = parallel_config is not None and (
parallel_config.context_parallel_config.ring_degree > 1
or parallel_config.context_parallel_config.ulysses_degree > 1
)
return supports_context_parallel and is_degree_greater_than_1
supports_context_parallel = backend.value in cls._supports_context_parallel
return supports_context_parallel
@dataclass
class _HubKernelConfig:
"""Configuration for downloading and using a hub-based attention kernel."""
repo_id: str
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None
# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
)
}
@contextlib.contextmanager
@@ -293,14 +312,6 @@ def dispatch_attention_fn(
backend_name = AttentionBackendName(backend)
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
backend_name, parallel_config
):
raise ValueError(
f"Backend {backend_name} either does not support context parallelism or context parallelism "
f"was enabled with a world size of 1."
)
kwargs = {
"query": query,
"key": key,
@@ -379,12 +390,18 @@ def _check_shape(
attn_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> None:
# Expected shapes:
# query: (batch_size, seq_len_q, num_heads, head_dim)
# key: (batch_size, seq_len_kv, num_heads, head_dim)
# value: (batch_size, seq_len_kv, num_heads, head_dim)
# attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv)
# or (batch_size, num_heads, seq_len_q, seq_len_kv)
if query.shape[-1] != key.shape[-1]:
raise ValueError("Query and key must have the same last dimension.")
if query.shape[-2] != value.shape[-2]:
raise ValueError("Query and value must have the same second to last dimension.")
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
raise ValueError("Attention mask must match the key's second to last dimension.")
raise ValueError("Query and key must have the same head dimension.")
if key.shape[-3] != value.shape[-3]:
raise ValueError("Key and value must have the same sequence length.")
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]:
raise ValueError("Attention mask must match the key's sequence length.")
# ===== Helper functions =====
@@ -405,13 +422,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
# TODO: add support Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
if not DIFFUSERS_ENABLE_HUB_KERNELS:
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
)
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
raise RuntimeError(
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
)
elif backend in [
@@ -555,6 +574,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return q_idx >= kv_idx
# ===== Helpers for downloading kernels =====
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
if config.kernel_fn is not None:
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
raise
# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
@@ -630,6 +672,86 @@ def _(
# ===== Helper functions to use attention backends with templated CP autograd functions =====
def _native_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
# Native attention does not return_lse
if return_lse:
raise ValueError("Native attention does not support return_lse=True")
# used for backward pass
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.attn_mask = attn_mask
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.enable_gqa = enable_gqa
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
return out
def _native_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
query, key, value = ctx.saved_tensors
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query_t,
key=key_t,
value=value_t,
attn_mask=ctx.attn_mask,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
enable_gqa=ctx.enable_gqa,
)
out = out.permute(0, 2, 1, 3)
grad_out_t = grad_out.permute(0, 2, 1, 3)
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
)
grad_query = grad_query_t.permute(0, 2, 1, 3)
grad_key = grad_key_t.permute(0, 2, 1, 3)
grad_value = grad_value_t.permute(0, 2, 1, 3)
return grad_query, grad_key, grad_value
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
# forward declaration:
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -1322,7 +1444,8 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
out = flash_attn_3_func_hub(
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
@@ -1397,6 +1520,47 @@ def _flash_varlen_attention_3(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.AITER,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _aiter_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if not return_lse and torch.is_grad_enabled():
# aiter requires return_lse=True by assertion when gradients are enabled.
out, lse, *_ = aiter_flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_lse=True,
)
else:
out = aiter_flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_lse=return_lse,
)
if return_lse:
out, lse, *_ = out
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLEX,
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
@@ -1463,6 +1627,7 @@ def _native_flex_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.NATIVE,
constraints=[_check_device, _check_shape],
supports_context_parallel=True,
)
def _native_attention(
query: torch.Tensor,
@@ -1478,18 +1643,35 @@ def _native_attention(
) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
if _parallel_config is None:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
else:
out = _templated_context_parallel_attention(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op=_native_attention_forward_op,
backward_op=_native_attention_backward_op,
_parallel_config=_parallel_config,
)
return out
+1 -3
View File
@@ -147,14 +147,13 @@ class AutoModel(ConfigMixin):
"force_download",
"local_files_only",
"proxies",
"resume_download",
"revision",
"token",
]
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder"]}
library = None
orig_class_name = None
@@ -205,7 +204,6 @@ class AutoModel(ConfigMixin):
module_file=module_file,
class_name=class_name,
**hub_kwargs,
**kwargs,
)
else:
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
@@ -5,6 +5,8 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
@@ -20,10 +20,10 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
KL loss for encoding images into latents and decoding latent representations into images.
@@ -107,9 +107,6 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
self.use_tiling = False
self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False)
@@ -27,7 +27,7 @@ from ..attention_processor import SanaMultiscaleLinearAttention
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization
from ..transformers.sana_transformer import GLUMBConv
from .vae import DecoderOutput, EncoderOutput
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
class ResBlock(nn.Module):
@@ -102,7 +102,7 @@ def get_block(
attention_head_dim: int,
norm_type: str,
act_fn: str,
qkv_mutliscales: Tuple[int] = (),
qkv_mutliscales: Tuple[int, ...] = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
@@ -206,8 +206,8 @@ class Encoder(nn.Module):
latent_channels: int,
attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle",
out_shortcut: bool = True,
@@ -292,8 +292,8 @@ class Decoder(nn.Module):
latent_channels: int,
attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
norm_type: Union[str, Tuple[str]] = "rms_norm",
act_fn: Union[str, Tuple[str]] = "silu",
@@ -378,7 +378,7 @@ class Decoder(nn.Module):
return hidden_states
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
[SANA](https://huggingface.co/papers/2410.10629).
@@ -440,8 +440,8 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 3, 3, 3),
decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
upsample_block_type: str = "pixel_shuffle",
@@ -536,27 +536,6 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
@@ -32,10 +32,10 @@ from ..attention_processor import (
)
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
@@ -138,35 +138,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -28,6 +28,7 @@ from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..resnet import ResnetBlock2D
from ..upsampling import Upsample2D
from .vae import AutoencoderMixin
class AllegroTemporalConvLayer(nn.Module):
@@ -673,7 +674,7 @@ class AllegroDecoder3D(nn.Module):
return sample
class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[Allegro](https://github.com/rhymes-ai/Allegro).
@@ -795,35 +796,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
sample_size - self.tile_overlap_w,
)
def enable_tiling(self) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = True
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
# TODO(aryan)
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
@@ -29,7 +29,7 @@ from ..downsampling import CogVideoXDownsample3D
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..upsampling import CogVideoXUpsample3D
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -955,7 +955,7 @@ class CogVideoXDecoder3D(nn.Module):
return hidden_states, new_conv_cache
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo).
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels: Tuple[int] = (128, 256, 256, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
latent_channels: int = 16,
layers_per_block: int = 3,
act_fn: str = "silu",
@@ -1124,27 +1124,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -24,7 +24,7 @@ from ...utils import get_logger
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, IdentityDistribution
from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution
logger = get_logger(__name__)
@@ -875,7 +875,7 @@ class CosmosDecoder3d(nn.Module):
return hidden_states
class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).
@@ -1031,27 +1031,6 @@ class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
enc = self.quant_conv(x)
@@ -26,7 +26,7 @@ from ..activations import get_activation
from ..attention_processor import Attention
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -624,7 +624,7 @@ class HunyuanVideoDecoder3D(nn.Module):
return hidden_states
class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
@@ -653,7 +653,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels: Tuple[int] = (128, 256, 512, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
act_fn: str = "silu",
norm_num_groups: int = 32,
@@ -763,27 +763,6 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -0,0 +1,709 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageResnetBlock(nn.Module):
r"""
Residual block with two convolutions and optional channel change.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.nonlinearity = get_activation(non_linearity)
# layers
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if in_channels != out_channels:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
def forward(self, x):
# Apply shortcut connection
residual = x
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.conv2(x)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
# Add residual connection
return x + residual
class HunyuanImageAttentionBlock(nn.Module):
r"""
Self-attention with a single head.
Args:
in_channels (int): The number of channels in the input tensor.
"""
def __init__(self, in_channels: int):
super().__init__()
# layers
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.to_q = nn.Conv2d(in_channels, in_channels, 1)
self.to_k = nn.Conv2d(in_channels, in_channels, 1)
self.to_v = nn.Conv2d(in_channels, in_channels, 1)
self.proj = nn.Conv2d(in_channels, in_channels, 1)
def forward(self, x):
identity = x
x = self.norm(x)
# compute query, key, value
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, height, width = query.shape
query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
# apply attention
x = F.scaled_dot_product_attention(query, key, value)
x = x.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
# output projection
x = self.proj(x)
return x + identity
class HunyuanImageDownsample(nn.Module):
"""
Downsampling block for spatial reduction.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
factor = 4
if out_channels % factor != 0:
raise ValueError(f"out_channels % factor != 0: {out_channels % factor}")
self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.group_size = factor * in_channels // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv(x)
B, C, H, W = h.shape
h = h.reshape(B, C, H // 2, 2, W // 2, 2)
h = h.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
h = h.reshape(B, 4 * C, H // 2, W // 2)
B, C, H, W = x.shape
shortcut = x.reshape(B, C, H // 2, 2, W // 2, 2)
shortcut = shortcut.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
shortcut = shortcut.reshape(B, 4 * C, H // 2, W // 2)
B, C, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
return h + shortcut
class HunyuanImageUpsample(nn.Module):
"""
Upsampling block for spatial expansion.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
factor = 4
self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
self.repeats = factor * out_channels // in_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv(x)
B, C, H, W = h.shape
h = h.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
h = h.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
h = h.reshape(B, C // 4, H * 2, W * 2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
B, C, H, W = shortcut.shape
shortcut = shortcut.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
shortcut = shortcut.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
shortcut = shortcut.reshape(B, C // 4, H * 2, W * 2)
return h + shortcut
class HunyuanImageMidBlock(nn.Module):
"""
Middle block for HunyuanImageVAE encoder and decoder.
Args:
in_channels (int): Number of input channels.
num_layers (int): Number of layers.
"""
def __init__(self, in_channels: int, num_layers: int = 1):
super().__init__()
resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]
attentions = []
for _ in range(num_layers):
attentions.append(HunyuanImageAttentionBlock(in_channels))
resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.resnets[0](x)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
x = attn(x)
x = resnet(x)
return x
class HunyuanImageEncoder2D(nn.Module):
r"""
Encoder network that compresses input to latent representation.
Args:
in_channels (int): Number of input channels.
z_channels (int): Number of latent channels.
block_out_channels (list of int): Output channels for each block.
num_res_blocks (int): Number of residual blocks per block.
spatial_compression_ratio (int): Spatial downsampling factor.
non_linearity (str): Type of non-linearity to use. Default is "silu".
downsample_match_channel (bool): Whether to match channels during downsampling.
"""
def __init__(
self,
in_channels: int,
z_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
spatial_compression_ratio: int,
non_linearity: str = "silu",
downsample_match_channel: bool = True,
):
super().__init__()
if block_out_channels[-1] % (2 * z_channels) != 0:
raise ValueError(
f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}"
)
self.in_channels = in_channels
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.spatial_compression_ratio = spatial_compression_ratio
self.group_size = block_out_channels[-1] // (2 * z_channels)
self.nonlinearity = get_activation(non_linearity)
# init block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
block_in_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
block_out_channel = block_out_channels[i]
# residual blocks
for _ in range(num_res_blocks):
self.down_blocks.append(
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
# downsample block
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
if downsample_match_channel:
block_out_channel = block_out_channels[i + 1]
self.down_blocks.append(
HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
# middle blocks
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1)
# output blocks
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_in(x)
## downsamples
for down_block in self.down_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(down_block, x)
else:
x = down_block(x)
## middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(self.mid_block, x)
else:
x = self.mid_block(x)
## head
B, C, H, W = x.shape
residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2)
x = self.norm_out(x)
x = self.nonlinearity(x)
x = self.conv_out(x)
return x + residual
class HunyuanImageDecoder2D(nn.Module):
r"""
Decoder network that reconstructs output from latent representation.
Args:
z_channels : int
Number of latent channels.
out_channels : int
Number of output channels.
block_out_channels : Tuple[int, ...]
Output channels for each block.
num_res_blocks : int
Number of residual blocks per block.
spatial_compression_ratio : int
Spatial upsampling factor.
upsample_match_channel : bool
Whether to match channels during upsampling.
non_linearity (str): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
spatial_compression_ratio: int,
upsample_match_channel: bool = True,
non_linearity: str = "silu",
):
super().__init__()
if block_out_channels[0] % z_channels != 0:
raise ValueError(
f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
)
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.repeat = block_out_channels[0] // z_channels
self.spatial_compression_ratio = spatial_compression_ratio
self.nonlinearity = get_activation(non_linearity)
self.conv_in = nn.Conv2d(z_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# Middle blocks with attention
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[0], num_layers=1)
# Upsampling blocks
block_in_channel = block_out_channels[0]
self.up_blocks = nn.ModuleList()
for i in range(len(block_out_channels)):
block_out_channel = block_out_channels[i]
for _ in range(self.num_res_blocks + 1):
self.up_blocks.append(
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
if upsample_match_channel:
block_out_channel = block_out_channels[i + 1]
self.up_blocks.append(HunyuanImageUpsample(block_in_channel, block_out_channel))
block_in_channel = block_out_channel
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid_block, h)
else:
h = self.mid_block(h)
for up_block in self.up_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(up_block, h)
else:
h = up_block(h)
h = self.norm_out(h)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model for 2D images with spatial tiling support.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = False
# fmt: off
@register_to_config
def __init__(
self,
in_channels: int,
out_channels: int,
latent_channels: int,
block_out_channels: Tuple[int, ...],
layers_per_block: int,
spatial_compression_ratio: int,
sample_size: int,
scaling_factor: float = None,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
) -> None:
# fmt: on
super().__init__()
self.encoder = HunyuanImageEncoder2D(
in_channels=in_channels,
z_channels=latent_channels,
block_out_channels=block_out_channels,
num_res_blocks=layers_per_block,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanImageDecoder2D(
z_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
num_res_blocks=layers_per_block,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
# Tiling and slicing configuration
self.use_slicing = False
self.use_tiling = False
# Tiling parameters
self.tile_sample_min_size = sample_size
self.tile_latent_min_size = sample_size // spatial_compression_ratio
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_size: Optional[int] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
allow processing larger images.
Args:
tile_sample_min_size (`int`, *optional*):
The minimum size required for a sample to be separated into tiles across the spatial dimension.
tile_overlap_factor (`float`, *optional*):
The overlap factor required for a latent to be separated into tiles across the spatial dimension.
"""
self.use_tiling = True
self.tile_sample_min_size = tile_sample_min_size or self.tile_sample_min_size
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor):
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self.tiled_encode(x)
enc = self.encoder(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True):
batch_size, num_channels, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode input using spatial tiling strategy.
Args:
x (`torch.Tensor`): Input tensor of shape (B, C, T, H, W).
Returns:
`torch.Tensor`:
The latent representation of the encoded images.
"""
_, _, _, height, width = x.shape
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
rows = []
for i in range(0, height, overlap_size):
row = []
for j in range(0, width, overlap_size):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode latent using spatial tiling strategy.
Args:
z (`torch.Tensor`): Latent tensor of shape (B, C, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, height, width = z.shape
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
rows = []
for i in range(0, height, overlap_size):
row = []
for j in range(0, width, overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
posterior = self.encode(sample).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
@@ -0,0 +1,934 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageRefinerCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
pad_mode: str = "replicate",
) -> None:
super().__init__()
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
self.pad_mode = pad_mode
self.time_causal_padding = (
kernel_size[0] // 2,
kernel_size[0] // 2,
kernel_size[1] // 2,
kernel_size[1] // 2,
kernel_size[2] - 1,
0,
)
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
return self.conv(hidden_states)
class HunyuanImageRefinerRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class HunyuanImageRefinerAttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = HunyuanImageRefinerRMS_norm(in_channels, images=False)
self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
x = self.norm(x)
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, frames, height, width = query.shape
query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None)
# batch_size, 1, frames * height * width, channels
x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
x = self.proj_out(x)
return x + identity
class HunyuanImageRefinerUpsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels * factor, kernel_size=3)
self.add_temporal_upsample = add_temporal_upsample
self.repeats = factor * out_channels // in_channels
@staticmethod
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
Args:
tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
r1: temporal upsampling factor
r2: height upsampling factor
r3: width upsampling factor
"""
b, packed_c, f, h, w = tensor.shape
factor = r1 * r2 * r3
c = packed_c // factor
tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
return tensor.reshape(b, c, f * r1, h * r2, w * r3)
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_upsample else 1
h = self.conv(x)
if self.add_temporal_upsample:
h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2)
h = h[:, : h.shape[1] // 2]
# shortcut computation
shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)
else:
h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
return h + shortcut
class HunyuanImageRefinerDownsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
assert out_channels % factor == 0
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels // factor, kernel_size=3)
self.add_temporal_downsample = add_temporal_downsample
self.group_size = factor * in_channels // out_channels
@staticmethod
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
This packs spatial/temporal dimensions into channels (opposite of upsample)
"""
b, c, packed_f, packed_h, packed_w = tensor.shape
f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_downsample else 1
h = self.conv(x)
if self.add_temporal_downsample:
# h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2)
h = torch.cat([h, h], dim=1)
# shortcut computation
# shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
else:
# h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
# shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
return h + shortcut
class HunyuanImageRefinerResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
non_linearity: str = "swish",
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.nonlinearity = get_activation(non_linearity)
self.norm1 = HunyuanImageRefinerRMS_norm(in_channels, images=False)
self.conv1 = HunyuanImageRefinerCausalConv3d(in_channels, out_channels, kernel_size=3)
self.norm2 = HunyuanImageRefinerRMS_norm(out_channels, images=False)
self.conv2 = HunyuanImageRefinerCausalConv3d(out_channels, out_channels, kernel_size=3)
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return hidden_states + residual
class HunyuanImageRefinerMidBlock(nn.Module):
def __init__(
self,
in_channels: int,
num_layers: int = 1,
add_attention: bool = True,
) -> None:
super().__init__()
self.add_attention = add_attention
# There is always at least one resnet
resnets = [
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
]
attentions = []
for _ in range(num_layers):
if self.add_attention:
attentions.append(HunyuanImageRefinerAttnBlock(in_channels))
else:
attentions.append(None)
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states)
return hidden_states
class HunyuanImageRefinerDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
downsample_out_channels: Optional[int] = None,
add_temporal_downsample: int = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample_out_channels is not None:
self.downsamplers = nn.ModuleList(
[
HunyuanImageRefinerDownsampleDCAE(
out_channels,
out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class HunyuanImageRefinerUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
upsample_out_channels: Optional[int] = None,
add_temporal_upsample: bool = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=input_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample_out_channels is not None:
self.upsamplers = nn.ModuleList(
[
HunyuanImageRefinerUpsampleDCAE(
out_channels,
out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
]
)
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for resnet in self.resnets:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class HunyuanImageRefinerEncoder3D(nn.Module):
r"""
3D vae encoder for HunyuanImageRefiner.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 64,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
temporal_compression_ratio: int = 4,
spatial_compression_ratio: int = 16,
downsample_match_channel: bool = True,
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.group_size = block_out_channels[-1] // self.out_channels
self.conv_in = HunyuanImageRefinerCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
add_spatial_downsample = i < np.log2(spatial_compression_ratio)
output_channel = block_out_channels[i]
if not add_spatial_downsample:
down_block = HunyuanImageRefinerDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=None,
add_temporal_downsample=False,
)
input_channel = output_channel
else:
add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
down_block = HunyuanImageRefinerDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
input_channel = downsample_out_channels
self.down_blocks.append(down_block)
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[-1])
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for down_block in self.down_blocks:
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
hidden_states = self.mid_block(hidden_states)
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
batch_size, _, frame, height, width = hidden_states.shape
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states += short_cut
return hidden_states
class HunyuanImageRefinerDecoder3D(nn.Module):
r"""
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
"""
def __init__(
self,
in_channels: int = 32,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
upsample_match_channel: bool = True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.in_channels = in_channels
self.out_channels = out_channels
self.repeat = block_out_channels[0] // self.in_channels
self.conv_in = HunyuanImageRefinerCausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[0])
# up
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
output_channel = block_out_channels[i]
add_spatial_upsample = i < np.log2(spatial_compression_ratio)
add_temporal_upsample = i < np.log2(temporal_compression_ratio)
if add_spatial_upsample or add_temporal_upsample:
upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
up_block = HunyuanImageRefinerUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
input_channel = upsample_out_channels
else:
up_block = HunyuanImageRefinerUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=None,
add_temporal_upsample=False,
)
input_channel = output_channel
self.up_blocks.append(up_block)
# out
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
for up_block in self.up_blocks:
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
else:
hidden_states = self.mid_block(hidden_states)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
# post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanImage-2.1 Refiner.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 32,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
scaling_factor: float = 1.03682,
) -> None:
super().__init__()
self.encoder = HunyuanImageRefinerEncoder3D(
in_channels=in_channels,
out_channels=latent_channels * 2,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanImageRefinerDecoder3D(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
self.spatial_compression_ratio = spatial_compression_ratio
self.temporal_compression_ratio = temporal_compression_ratio
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
x = self.encoder(x)
return x
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z)
dec = self.decoder(z)
return dec
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, _, height, width = x.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
row_limit_height = tile_latent_min_height - blend_height # 8 - 2 = 6
row_limit_width = tile_latent_min_width - blend_width # 8 - 2 = 6
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, _, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
row_limit_height = tile_latent_min_height - blend_height # 256 - 64 = 192
row_limit_width = tile_latent_min_width - blend_width # 256 - 64 = 192
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = z[
:,
:,
:,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
return dec
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
@@ -26,7 +26,7 @@ from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
class LTXVideoCausalConv3d(nn.Module):
@@ -1034,7 +1034,7 @@ class LTXVideoDecoder3d(nn.Module):
return hidden_states
class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -1219,27 +1219,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -663,7 +663,7 @@ class EasyAnimateDecoder(nn.Module):
return hidden_states
class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
@@ -805,27 +805,6 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def _encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -27,7 +27,7 @@ from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -657,7 +657,7 @@ class MochiDecoder3D(nn.Module):
return hidden_states, new_conv_cache
class AutoencoderKLMochi(ModelMixin, ConfigMixin):
class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[Mochi 1 preview](https://github.com/genmoai/models).
@@ -688,8 +688,8 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
self,
in_channels: int = 15,
out_channels: int = 3,
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
encoder_block_out_channels: Tuple[int, ...] = (64, 128, 256, 384),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
latent_channels: int = 12,
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
act_fn: str = "silu",
@@ -818,27 +818,6 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _enable_framewise_encoding(self):
r"""
Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the
@@ -16,7 +16,7 @@
# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
# For more information about the Wan VAE, please refer to:
# - GitHub: https://github.com/Wan-Video/Wan2.1
# - arXiv: https://arxiv.org/abs/2503.20314
# - Paper: https://huggingface.co/papers/2503.20314
from typing import List, Optional, Tuple, Union
@@ -31,7 +31,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -663,7 +663,7 @@ class QwenImageDecoder3d(nn.Module):
return x
class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
@@ -679,7 +679,7 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self,
base_dim: int = 96,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
@@ -763,27 +763,6 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self):
def _count_conv3d(model):
count = 0
@@ -23,7 +23,7 @@ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
self,
in_channels: int = 4,
out_channels: int = 3,
block_out_channels: Tuple[int] = (128, 256, 512, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
):
super().__init__()
@@ -135,7 +135,7 @@ class TemporalDecoder(nn.Module):
return sample
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1,
latent_channels: int = 4,
sample_size: int = 32,
@@ -25,7 +25,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -453,14 +453,14 @@ class WanMidBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
x = self.resnets[0](x, feat_cache, feat_idx)
x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
x = resnet(x, feat_cache, feat_idx)
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
return x
@@ -494,9 +494,9 @@ class WanResidualDownBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for resnet in self.resnets:
x = resnet(x, feat_cache, feat_idx)
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
if self.downsampler is not None:
x = self.downsampler(x, feat_cache, feat_idx)
x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
return x + self.avg_shortcut(x_copy)
@@ -598,12 +598,12 @@ class WanEncoder3d(nn.Module):
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = layer(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
## head
x = self.norm_out(x)
@@ -694,13 +694,13 @@ class WanResidualUpBlock(nn.Module):
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = resnet(x)
if self.upsampler is not None:
if feat_cache is not None:
x = self.upsampler(x, feat_cache, feat_idx)
x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = self.upsampler(x)
@@ -767,13 +767,13 @@ class WanUpBlock(nn.Module):
"""
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
x = self.upsamplers[0](x, feat_cache, feat_idx)
x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = self.upsamplers[0](x)
return x
@@ -885,11 +885,11 @@ class WanDecoder3d(nn.Module):
x = self.conv_in(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
@@ -951,7 +951,7 @@ def unpatchify(x, patch_size):
return x
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
@@ -961,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
_supports_gradient_checkpointing = False
# keys toignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
@register_to_config
def __init__(
@@ -968,7 +971,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
base_dim: int = 96,
decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
dim_mult: List[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
@@ -1110,27 +1113,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self):
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"]
@@ -1355,9 +1337,18 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
tile_sample_stride_height = self.tile_sample_stride_height
tile_sample_stride_width = self.tile_sample_stride_width
if self.config.patch_size is not None:
sample_height = sample_height // self.config.patch_size
sample_width = sample_width // self.config.patch_size
tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
else:
blend_height = self.tile_sample_min_height - tile_sample_stride_height
blend_width = self.tile_sample_min_width - tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
@@ -1371,7 +1362,9 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
decoded = self.decoder(
tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
@@ -1387,11 +1380,15 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if self.config.patch_size is not None:
dec = unpatchify(dec, patch_size=self.config.patch_size)
dec = torch.clamp(dec, min=-1.0, max=1.0)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@@ -25,6 +25,7 @@ from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin
class Snake1d(nn.Module):
@@ -291,7 +292,7 @@ class OobleckDecoder(nn.Module):
return hidden_state
class AutoencoderOobleck(ModelMixin, ConfigMixin):
class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
introduced in Stable Audio.
@@ -356,20 +357,6 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
self.use_slicing = False
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True

Some files were not shown because too many files have changed in this diff Show More