Compare commits

...

55 Commits

Author SHA1 Message Date
Dhruv Nair 98954fc2e1 update 2025-07-28 05:33:00 +02:00
DN6 1262d19d16 update 2025-07-28 08:32:01 +05:30
YiYi Xu 201da97dd0 Merge branch 'main' into custom-code-updates 2025-07-23 10:23:35 -10:00
Aryan f36ba9f094 [modular diffusers] Wan (#11913)
* update
2025-07-23 06:19:40 -10:00
Sayak Paul 1c50a5f7e0 [tests] enforce torch version in the compilation tests. (#11979)
enforce torch version in the compilation tests.
2025-07-23 19:42:46 +05:30
Sayak Paul 7ae6347e33 [docs] update guidance_scale docstring for guidance_distilled models. (#11935)
* update guidance_scale docstring for guidance_distilled models.

* Update pipeline_flux.py

* Update pipeline_flux_control.py

* Update pipeline_flux_kontext.py

* Update pipeline_flux_kontext_inpaint.py

* Update pipeline_sana_sprint.py

* style

* Update pipeline_hidream_image.py

* Update pipeline_chroma.py

* Update pipeline_chroma_img2img.py

* Update pipeline_hunyuan_video.py

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-07-23 17:49:38 +05:30
Aryan 178d32dedd [tests] Add test slices for Wan (#11920)
* update

* fix wan vace test slice

* test

* fix
2025-07-23 17:23:52 +05:30
YiYi Xu ef1e628729 fix style (#11975)
up
2025-07-22 10:25:40 -10:00
Sam Gao 173e1b147d [Examples] Uniform notations in train_flux_lora (#10011)
[Examples] uniform naming notations

since the in parameter `size` represents `args.resolution`, I thus replace the `args.resolution` inside DreamBoothData with `size`. And revise some notations such as `center_crop`.

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-07-22 09:14:00 -10:00
Aryan e46e139f95 Remove logger warnings for attention backends and hard error during runtime instead (#11967)
* update

* update

* update
2025-07-22 20:47:44 +05:30
DN6 4423097b23 update 2025-07-22 19:31:22 +05:30
Yao Matrix 14725164be fix "Expected all tensors to be on the same device, but found at least two devices" error (#11690)
* xx

* fix

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Update model_loading_utils.py

* Update test_models_unet_2d_condition.py

* Update test_models_unet_2d_condition.py

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix comments

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* Update unet_2d_blocks.py

* update

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-07-22 13:39:24 +02:00
YiYi Xu 638cc035e5 [Modular] update the collection behavior (#11963)
* only remove from the collection
2025-07-21 08:47:07 -10:00
Dhruv Nair 60d1b81023 update 2025-07-21 18:44:44 +02:00
Aryan 9db9be65f3 [tests] Add fast test slices for HiDream-Image (#11953)
update
2025-07-21 07:53:13 +05:30
Aryan d87134ada4 [tests] Add test slices for Cosmos (#11955)
* test

* try fix
2025-07-21 07:52:44 +05:30
Aryan 67a8ec8bf5 [tests] Add test slices for Hunyuan Video (#11954)
update
2025-07-21 07:52:16 +05:30
Chengxi Guo cde02b061b Fix kontext finetune issue when batch size >1 (#11921)
set drop_last to True

Signed-off-by: mymusise <mymusise1@gmail.com>
2025-07-18 19:38:58 -04:00
Sayak Paul 5dc503aa28 [docs] include bp link. (#11952)
* include bp link.

* Update docs/source/en/optimization/fp16.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* resources.

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-07-18 22:17:13 +01:00
Steven Liu c6fbcf717b [docs] Update toctree (#11936)
* update

* fix

* feedback

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-07-18 13:37:04 -07:00
Dhruv Nair b9e99654e1 [Modular] Updates for Custom Pipeline Blocks (#11940)
* update

* update

* update
2025-07-18 15:01:50 +02:00
Sayak Paul 478df933c3 [docs] clarify the mapping between Transformer2DModel and finegrained variants. (#11947)
* clarify the mapping between Transformer2DModel and finegrained variants.

* Update src/diffusers/pipelines/dit/pipeline_dit.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-07-18 08:28:51 +01:00
Aryan 18c8f10f20 [refactor] Flux/Chroma single file implementation + Attention Dispatcher (#11916)
* update

* update

* add coauthor

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* improve test

* handle ip adapter params correctly

* fix chroma qkv fusion test

* fix fastercache implementation

* fix more tests

* fight more tests

* add back set_attention_backend

* update

* update

* make style

* make fix-copies

* make ip adapter processor compatible with attention dispatcher

* refactor chroma as well

* remove rmsnorm assert

* minify and deprecate npu/xla processors

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-07-17 17:30:39 +05:30
Tolga Cangöz 7298bdd817 Add SkyReels V2: Infinite-Length Film Generative Model (#11518)
* style

* Fix class name casing for SkyReelsV2 components in multiple files to ensure consistency and correct functionality.

* cleaning

* cleansing

* Refactor `get_timestep_embedding` to move modifications into `SkyReelsV2TimeTextImageEmbedding`.

* Remove unnecessary line break in `get_timestep_embedding` function for cleaner code.

* Remove `skyreels_v2` entry from `_import_structure` and update its initialization to directly assign the list of SkyReelsV2 components.

* cleansing

* Refactor attention processing in `SkyReelsV2AttnProcessor2_0` to always convert query, key, and value to `torch.bfloat16`, simplifying the code and improving clarity.

* Enhance example usage in `pipeline_skyreels_v2_diffusion_forcing.py` by adding VAE initialization and detailed prompt for video generation, improving clarity and usability of the documentation.

* Refactor import structure in `__init__.py` for SkyReelsV2 components and improve formatting in `pipeline_skyreels_v2_diffusion_forcing.py` to enhance code readability and maintainability.

* Update `guidance_scale` parameter in `SkyReelsV2DiffusionForcingPipeline` from 5.0 to 6.0 to enhance video generation quality.

* Update `guidance_scale` parameter in example documentation and class definition of `SkyReelsV2DiffusionForcingPipeline` to ensure consistency and improve video generation quality.

* Update `causal_block_size` parameter in `SkyReelsV2DiffusionForcingPipeline` to default to `None`.

* up

* Fix dtype conversion for `timestep_proj` in `SkyReelsV2Transformer3DModel` to *ensure* correct tensor operations.

* Optimize causal mask generation by replacing repeated tensor with `repeat_interleave` for improved efficiency in `SkyReelsV2Transformer3DModel`.

* style

* Enhance example documentation in `SkyReelsV2DiffusionForcingPipeline` with guidance scale and shift parameters for T2V and I2V. Remove unused `retrieve_latents` function to streamline the code.

* Refactor sample scheduler creation in `SkyReelsV2DiffusionForcingPipeline` to use `deepcopy` for improved state management during inference steps.

* Enhance error handling and documentation in `SkyReelsV2DiffusionForcingPipeline` for `overlap_history` and `addnoise_condition` parameters to improve long video generation guidance.

* Update documentation and progress bar handling in `SkyReelsV2DiffusionForcingPipeline` to clarify asynchronous inference settings and improve progress tracking during denoising steps.

* Refine progress bar calculation in `SkyReelsV2DiffusionForcingPipeline` by rounding the step size to one decimal place for improved readability during denoising steps.

* Update import statements in `SkyReelsV2DiffusionForcingPipeline` documentation for improved clarity and organization.

* Refactor progress bar handling in `SkyReelsV2DiffusionForcingPipeline` to use total steps instead of calculated step size.

* update templates for i2v, v2v

* Add `retrieve_latents` function to streamline latent retrieval in `SkyReelsV2DiffusionForcingPipeline`. Update video latent processing to utilize this new function for improved clarity and maintainability.

* Add `retrieve_latents` function to both i2v and v2v pipelines for consistent latent retrieval. Update video latent processing to utilize this function, enhancing clarity and maintainability across the SkyReelsV2DiffusionForcingPipeline implementations.

* Remove redundant ValueError for `overlap_history` in `SkyReelsV2DiffusionForcingPipeline` to streamline error handling and improve user guidance for long video generation.

* Update default video dimensions and flow matching scheduler parameter in `SkyReelsV2DiffusionForcingPipeline` to enhance video generation capabilities.

* Refactor `SkyReelsV2DiffusionForcingPipeline` to support Image-to-Video (i2v) generation. Update class name, add image encoding functionality, and adjust parameters for improved video generation. Enhance error handling for image inputs and update documentation accordingly.

* Improve organization for image-last_image condition.

* Refactor `SkyReelsV2DiffusionForcingImageToVideoPipeline` to improve latent preparation and video condition handling integration.

* style

* style

* Add example usage of PIL for image input in `SkyReelsV2DiffusionForcingImageToVideoPipeline` documentation.

* Refactor `SkyReelsV2DiffusionForcingPipeline` to `SkyReelsV2DiffusionForcingVideoToVideoPipeline`, enhancing support for Video-to-Video (v2v) generation. Introduce video input handling, update latent preparation logic, and improve error handling for input parameters.

* Refactor `SkyReelsV2DiffusionForcingImageToVideoPipeline` by removing the `image_encoder` and `image_processor` dependencies. Update the CPU offload sequence accordingly.

* Refactor `SkyReelsV2DiffusionForcingImageToVideoPipeline` to enhance latent preparation logic and condition handling. Update image input type to `Optional`, streamline video condition processing, and improve handling of `last_image` during latent generation.

* Enhance `SkyReelsV2DiffusionForcingPipeline` by refining latent preparation for long video generation. Introduce new parameters for video handling, overlap history, and causal block size. Update logic to accommodate both short and long video scenarios, ensuring compatibility and improved processing.

* refactor

* fix num_frames

* fix prefix_video_latents

* up

* refactor

* Fix typo in scheduler method call within `SkyReelsV2DiffusionForcingVideoToVideoPipeline` to ensure proper noise scaling during latent generation.

* up

* Enhance `SkyReelsV2DiffusionForcingImageToVideoPipeline` by adding support for `last_image` parameter and refining latent frame calculations. Update preprocessing logic.

* add statistics

* Refine latent frame handling in `SkyReelsV2DiffusionForcingImageToVideoPipeline` by correcting variable names and reintroducing latent mean and standard deviation calculations. Update logic for frame preparation and sampling to ensure accurate video generation.

* up

* refactor

* up

* Refactor `SkyReelsV2DiffusionForcingVideoToVideoPipeline` to improve latent handling by enforcing tensor input for video, updating frame preparation logic, and adjusting default frame count. Enhance preprocessing and postprocessing steps for better integration.

* style

* fix vae output indexing

* upup

* up

* Fix tensor concatenation and repetition logic in `SkyReelsV2DiffusionForcingImageToVideoPipeline` to ensure correct dimensionality for video conditions and latent conditions.

* Refactor latent retrieval logic in `SkyReelsV2DiffusionForcingVideoToVideoPipeline` to handle tensor dimensions more robustly, ensuring compatibility with both 3D and 4D video inputs.

* Enhance logging in `SkyReelsV2DiffusionForcing` pipelines by adding iteration print statements for better debugging. Clean up unused code related to prefix video latents length calculation in `SkyReelsV2DiffusionForcingImageToVideoPipeline`.

* Update latent handling in `SkyReelsV2DiffusionForcingImageToVideoPipeline` to conditionally set latents based on video iteration state, improving flexibility for video input processing.

* Refactor `SkyReelsV2TimeTextImageEmbedding` to utilize `get_1d_sincos_pos_embed_from_grid` for timestep projection.

* Enhance `get_1d_sincos_pos_embed_from_grid` function to include an optional parameter `flip_sin_to_cos` for flipping sine and cosine embeddings, improving flexibility in positional embedding generation.

* Update timestep projection in `SkyReelsV2TimeTextImageEmbedding` to include `flip_sin_to_cos` parameter, enhancing the flexibility of time embedding generation.

* Refactor tensor type handling in `SkyReelsV2AttnProcessor2_0` and `SkyReelsV2TransformerBlock` to ensure consistent use of `torch.float32` and `torch.bfloat16`, improving integration.

* Update tensor type in `SkyReelsV2RotaryPosEmbed` to use `torch.float32` for frequency calculations, ensuring consistency in data types across the model.

* Refactor `SkyReelsV2TimeTextImageEmbedding` to utilize automatic mixed precision for timestep projection.

* down

* down

* style

* Add debug tensor tracking to `SkyReelsV2Transformer3DModel` for enhanced debugging and output analysis; update `Transformer2DModelOutput` to include debug tensors.

* up

* Refactor indentation in `SkyReelsV2AttnProcessor2_0` to improve code readability and maintain consistency in style.

* Convert query, key, and value tensors to bfloat16 in `SkyReelsV2AttnProcessor2_0` for improved performance.

* Add debug print statements in `SkyReelsV2TransformerBlock` to track tensor shapes and values for improved debugging and analysis.

* debug

* debug

* Remove commented-out debug tensor tracking from `SkyReelsV2TransformerBlock`

* Add functionality to save processed video latents as a Safetensors file in `SkyReelsV2DiffusionForcingPipeline`.

* up

* Add functionality to save output latents as a Safetensors file in `SkyReelsV2DiffusionForcingPipeline`.

* up

* Remove additional commented-out debug tensor tracking from `SkyReelsV2TransformerBlock` and `SkyReelsV2Transformer3DModel` for cleaner code.

* style

* cleansing

* Update example documentation and parameters in `SkyReelsV2Pipeline`. Adjusted example code for loading models, modified default values for height, width, num_frames, and guidance_scale, and improved output video quality settings.

* Update shift parameter in example documentation and default values across SkyReels V2 pipelines. Adjusted shift values for I2V from 3.0 to 5.0 and updated related example code for consistency.

* Update example documentation in SkyReels V2 pipelines to include available model options and update model references for loading. Adjusted model names to reflect the latest versions across I2V, V2V, and T2V pipelines.

* Add test templates

* style

* Add docs template

* Add SkyReels V2 Diffusion Forcing Video-to-Video Pipeline to imports

* style

* fix-copies

* convert i2v 1.3b

* Update transformer configuration to include `image_dim` for SkyReels V2 models and refactor imports to use `SkyReelsV2Transformer3DModel`.

* Refactor transformer import in SkyReels V2 pipeline to use `SkyReelsV2Transformer3DModel` for consistency.

* Update transformer configuration in SkyReels V2 to increase `in_channels` from 16 to 36 for i2v conf.

* Update transformer configuration in SkyReels V2 to set `added_kv_proj_dim` values for different model types.

* up

* up

* up

* Add SkyReelsV2Pipeline support for T2V model type in conversion script

* upp

* Refactor model type checks in conversion script to use substring matching for improved flexibility

* upp

* Fix shard path formatting in conversion script to accommodate varying model types by dynamically adjusting zero padding.

* Update sharded safetensors loading logic in conversion script to use substring matching for model directory checks

* Update scheduler parameters in SkyReels V2 test files for consistency across image and video pipelines

* Refactor conversion script to initialize text encoder, tokenizer, and scheduler for SkyReels pipelines, enhancing model integration

* style

* Update documentation for SkyReels-V2, introducing the Infinite-length Film Generative model, enhancing text-to-video generation examples, and updating model references throughout the API documentation.

* Add SkyReelsV2Transformer3DModel and FlowMatchUniPCMultistepScheduler documentation, updating TOC and introducing new model and scheduler files.

* style

* Update documentation for SkyReelsV2DiffusionForcingPipeline to correct flow matching scheduler parameter for I2V from 3.0 to 5.0, ensuring clarity in usage examples.

* Add documentation for causal_block_size parameter in SkyReelsV2DF pipelines, clarifying its role in asynchronous inference.

* Simplify min_ar_step calculation in SkyReelsV2DiffusionForcingPipeline to improve clarity.

* style and fix-copies

* style

* Add documentation for SkyReelsV2Transformer3DModel

Introduced a new markdown file detailing the SkyReelsV2Transformer3DModel, including usage instructions and model output specifications.

* Update test configurations for SkyReelsV2 pipelines

- Adjusted `in_channels` from 36 to 16 in `test_skyreels_v2_df_image_to_video.py`.
- Added new parameters: `overlap_history`, `num_frames`, and `base_num_frames` in `test_skyreels_v2_df_video_to_video.py`.
- Updated expected output shape in video tests from (17, 3, 16, 16) to (41, 3, 16, 16).

* Refines SkyReelsV2DF test parameters

* Update src/diffusers/models/modeling_outputs.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Refactor `grid_sizes` processing by using already-calculated post-patch parameters to simplify

* Update docs/source/en/api/pipelines/skyreels_v2.md

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Refactor parameter naming for diffusion forcing in SkyReelsV2 pipelines

- Changed `flag_df` to `enable_diffusion_forcing` for clarity in the SkyReelsV2Transformer3DModel and associated pipelines.
- Updated all relevant method calls to reflect the new parameter name.

* Revert _toctree.yml to adjust section expansion states

* style

* Update docs/source/en/api/models/skyreels_v2_transformer_3d.md

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Add copying label to SkyReelsV2ImageEmbedding from WanImageEmbedding.

* Refactor transformer block processing in SkyReelsV2Transformer3DModel

- Ensured proper handling of hidden states during both gradient checkpointing and standard processing.

* Update SkyReels V2 documentation to remove VRAM requirement and streamline imports

- Removed the mention of ~13GB VRAM requirement for the SkyReels-V2 model.
- Simplified import statements by removing unused `load_image` import.

* Add SkyReelsV2LoraLoaderMixin for loading and managing LoRA layers in SkyReelsV2Transformer3DModel

- Introduced SkyReelsV2LoraLoaderMixin class to handle loading, saving, and fusing of LoRA weights specific to the SkyReelsV2 model.
- Implemented methods for state dict management, including compatibility checks for various LoRA formats.
- Enhanced functionality for loading weights with options for low CPU memory usage and hotswapping.
- Added detailed docstrings for clarity on parameters and usage.

* Update SkyReelsV2 documentation and loader mixin references

- Corrected the documentation to reference the new `SkyReelsV2LoraLoaderMixin` for loading LoRA weights.
- Updated comments in the `SkyReelsV2LoraLoaderMixin` class to reflect changes in model references from `WanTransformer3DModel` to `SkyReelsV2Transformer3DModel`.

* Enhance SkyReelsV2 integration by adding SkyReelsV2LoraLoaderMixin references

- Added `SkyReelsV2LoraLoaderMixin` to the documentation and loader imports for improved LoRA weight management.
- Updated multiple pipeline classes to inherit from `SkyReelsV2LoraLoaderMixin` instead of `WanLoraLoaderMixin`.

* Update SkyReelsV2 model references in documentation

- Replaced placeholder model paths with actual paths for SkyReels-V2 models in multiple pipeline files.
- Ensured consistency across the documentation for loading models in the SkyReelsV2 pipelines.

* style

* fix-copies

* Refactor `fps_projection` in `SkyReelsV2Transformer3DModel`

- Replaced the sequential linear layers for `fps_projection` with a `FeedForward` layer using `SiLU` activation for better integration.

* Update docs

* Refactor video processing in SkyReelsV2DiffusionForcingPipeline

- Renamed parameters for clarity: `video` to `video_latents` and `overlap_history` to `overlap_history_latent_frames`.
- Updated logic for handling long video generation, including adjustments to latent frame calculations and accumulation.
- Consolidated handling of latents for both long and short video generation scenarios.
- Final decoding step now consistently converts latents to pixels, ensuring proper output format.

* Update activation function in `fps_projection` of `SkyReelsV2Transformer3DModel`

- Changed activation function from `silu` to `linear-silu` in the `fps_projection` layer for improved performance and integration.

* Add fps_projection layer renaming in convert_skyreelsv2_to_diffusers.py

- Updated key mappings for the `fps_projection` layer to align with new naming conventions, ensuring consistency in model integration.

* Fix fps_projection assignment in SkyReelsV2Transformer3DModel

- Corrected the assignment of the `fps_projection` layer to ensure it is properly cast to the appropriate data type, enhancing model functionality.

* Update _keep_in_fp32_modules in SkyReelsV2Transformer3DModel

- Added `fps_projection` to the list of modules that should remain in FP32 precision, ensuring proper handling of data types during model operations.

* Remove integration test classes from SkyReelsV2 test files

- Deleted the `SkyReelsV2DiffusionForcingPipelineIntegrationTests` and `SkyReelsV2PipelineIntegrationTests` classes along with their associated setup, teardown, and test methods, as they were not implemented and not needed for current testing.

* style

* Refactor: Remove hardcoded `torch.bfloat16` cast in attention

* Refactor: Simplify data type handling in transformer model

Removes unnecessary data type conversions for the FPS embedding and timestep projection.

This change simplifies the forward pass by relying on the inherent data types of the tensors.

* Refactor: Remove `fps_projection` from `_keep_in_fp32_modules` in `SkyReelsV2Transformer3DModel`

* Update src/diffusers/models/transformers/transformer_skyreels_v2.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Refactor: Remove unused flags and simplify attention mask handling in SkyReelsV2AttnProcessor2_0 and SkyReelsV2Transformer3DModel

Refactor: Simplify causal attention logic in SkyReelsV2

Removes the `flag_causal_attention` and `_flag_ar_attention` flags to simplify the implementation.

The decision to apply a causal attention mask is now based directly on the `num_frame_per_block` configuration, eliminating redundant flags and conditional checks. This streamlines the attention mechanism and simplifies the `set_ar_attention` methods.

* Refactor: Clarify variable names for latent frames

Renames `base_num_frames` to `base_latent_num_frames` to make it explicit that the variable refers to the number of frames in the latent space.

This change improves code readability and reduces potential confusion between latent frames and decoded video frames.

The `num_frames` parameter in `generate_timestep_matrix` is also renamed to `num_latent_frames` for consistency.

* Enhance documentation: Add detailed docstring for timestep matrix generation in SkyReelsV2DiffusionForcingPipeline

* Docs: Clarify long video chunking in pipeline docstring

Improves the explanation of long video processing within the pipeline's docstring.

The update replaces the abstract description with a concrete example, illustrating how the sliding window mechanism works with overlapping chunks. This makes the roles of `base_num_frames` and `overlap_history` clearer for users.

* Docs: Move visual demonstration and processing details for SkyReelsV2DiffusionForcingPipeline to docs page from the code

* Docs: Update asynchronous processing timeline and examples for long video handling in SkyReels-V2 documentation

* Enhance timestep matrix generation documentation and logic for synchronous/asynchronous video processing

* Update timestep matrix documentation and enhance analysis for clarity in SkyReelsV2DiffusionForcingPipeline

* Docs: Update visual demonstration section and add detailed step matrix construction example for asynchronous processing in SkyReelsV2DiffusionForcingPipeline

* style

* fix-copies

* Refactor parameter names for clarity in SkyReelsV2DiffusionForcingImageToVideoPipeline and SkyReelsV2DiffusionForcingVideoToVideoPipeline

* Refactor: Avoid VAE roundtrip in long video generation

Improves performance and quality for long video generation by operating entirely in latent space during the iterative generation process.

Instead of decoding latents to video and then re-encoding the overlapping section for the next chunk, this change passes the generated latents directly between iterations.

This avoids a computationally expensive and potentially lossy VAE decode/encode cycle within the loop. The full video is now decoded only once from the accumulated latents at the end of the process.

* Refactor: Rename prefix_video_latents_length to prefix_video_latents_frames for clarity

* Refactor: Rename num_latent_frames to current_num_latent_frames for clarity in SkyReelsV2DiffusionForcingImageToVideoPipeline

* Refactor: Enhance long video generation logic and improve latent handling in SkyReelsV2DiffusionForcingImageToVideoPipeline

Refactor: Unify video generation and pass latents directly

Unifies the separate code paths for short and long video generation into a single, streamlined loop.

This change eliminates the inefficient decode-encode cycle during long video generation. Instead of converting latents to pixel-space video between chunks, the pipeline now passes the generated latents directly to the next iteration.

This improves performance, avoids potential quality loss from intermediate VAE steps, and enhances code maintainability by removing significant duplication.

* style

* Refactor: Remove overlap_history parameter and streamline long video generation logic in SkyReelsV2DiffusionForcingImageToVideoPipeline

Refactor: Streamline long video generation logic

Removes the `overlap_history` parameter and simplifies the conditioning process for long video generation.

This change avoids a redundant VAE encoding step by directly using latent frames from the previous chunk for conditioning. It also moves image preprocessing outside the main generation loop to prevent repeated computations and clarifies the handling of prefix latents.

* style

* Refactor latent handling in i2v diffusion forcing pipeline

Improves the latent conditioning and accumulation logic within the image-to-video diffusion forcing loop.

- Corrects the splitting of the initial conditioning tensor to robustly handle both even and odd lengths.
- Simplifies how latents are accumulated across iterations for long video generation.
- Ensures the final latents are trimmed correctly before decoding only when a `last_image` is provided.

* Refactor: Remove overlap_history parameter from SkyReelsV2DiffusionForcingImageToVideoPipeline

* Refactor: Adjust video_latents parameter handling in prepare_latents method

* style

* Refactor: Update long video iteration print statements for clarity

* Fix: Update transformer config with dynamic causal block size

Updates the SkyReelsV2 pipelines to correctly set the `causal_block_size` in the transformer's configuration when it's provided during a pipeline call.

This ensures the model configuration reflects the user's specified setting for the inference run. The `set_ar_attention` method is also renamed to `_set_ar_attention` to mark it as an internal helper.

* style

* Refactor: Adjust video input size and expected output shape in inference test

* Refactor: Rename video variables for clarity in SkyReelsV2DiffusionForcingVideoToVideoPipeline

* Docs: Clarify time embedding logic in SkyReelsV2

Adds comments to explain the handling of different time embedding tensor dimensions.

A 2D tensor is used for standard models with a single time embedding per batch, while a 3D tensor is used for Diffusion Forcing models where each frame has its own time embedding. This clarifies the expected input for different model variations.

* Docs: Update SkyReels V2 pipeline examples

Updates the docstring examples for the SkyReels V2 pipelines to reflect current best practices and API changes.

- Removes the `shift` parameter from pipeline call examples, as it is now configured directly on the scheduler.
- Replaces the `set_ar_attention` method call with the `causal_block_size` argument in the pipeline call for diffusion forcing examples.
- Adjusts recommended parameters for I2V and V2V examples, including inference steps, guidance scale, and `ar_step`.

* Refactor: Remove `shift` parameter from SkyReelsV2 pipelines

Removes the `shift` parameter from the call signature of all SkyReelsV2 pipelines.

This parameter is a scheduler-specific configuration and should be set directly on the scheduler during its initialization, rather than being passed at runtime through the pipeline. This change simplifies the pipeline API.

Usage examples are updated to reflect that the `shift` value should now be passed when creating the `FlowMatchUniPCMultistepScheduler`.

* Refactors SkyReelsV2 image-to-video tests and adds last image case

Simplifies the test suite by removing a duplicated test class and streamlining the dummy component and input generation.

Adds a new test to verify the pipeline's behavior when a `last_image` is provided as input for conditioning.

* test: Add image components to SkyReelsV2 pipeline test

Adds the `image_encoder` and `image_processor` to the test components for the image-to-video pipeline.

Also replaces a hardcoded value for the positional embedding sequence length with a more descriptive calculation, improving clarity.

* test: Add callback configuration test for SkyReelsV2DiffusionForcingVideoToVideoPipeline

test: Add callback test for SkyReelsV2DFV2V pipeline

Adds a test to validate the callback functionality for the `SkyReelsV2DiffusionForcingVideoToVideoPipeline`.

This test confirms that `callback_on_step_end` is invoked correctly and can modify the pipeline's state during inference. It uses a callback to dynamically increase the `guidance_scale` and asserts that the final value is as expected.

The implementation correctly accounts for the nested denoising loops present in diffusion forcing pipelines.

* style

* fix: Update image_encoder type to CLIPVisionModelWithProjection in SkyReelsV2ImageToVideoPipeline

* UP

* Add conversion support for SkyReels-V2-FLF2V models

Adds configurations for three new FLF2V model variants (1.3B-540P, 14B-540P, and 14B-720P) to the conversion script.

This change also introduces specific handling to zero out the image positional embeddings for these models and updates the main script to correctly initialize the image-to-video pipeline.

* Docs: Update and simplify SkyReels V2 usage examples

Simplifies the text-to-video example by removing the manual group offloading configuration, making it more straightforward.

Adds comments to pipeline parameters to clarify their purpose and provides guidance for different resolutions and long video generation.

Introduces a new section with a code example for the video-to-video pipeline.

* style

* docs: Add SkyReels-V2 FLF2V 1.3B model to supported models list

* docs: Update SkyReels-V2 documentation

* Move the initialization of the `gradient_checkpointing` attribute to its suggested location.

* Refactor: Use logger for long video progress messages

Replaces `print()` calls with `logger.debug()` for reporting progress during long video generation in SkyReelsV2DF pipelines.

This change reduces console output verbosity for standard runs while allowing developers to view progress by enabling debug-level logging.

* Refactor SkyReelsV2 timestep embedding into a module

Extract the sinusoidal timestep embedding logic into a new `SkyReelsV2Timesteps` `nn.Module`.

This change encapsulates the embedding generation, which simplifies the `SkyReelsV2TimeTextImageEmbedding` class and improves code modularity.

* Fix: Preserve original shape in timestep embeddings

Reshapes the timestep embedding tensor to match the original input shape.

This ensures that batched timestep inputs retain their batch dimension after embedding, preventing potential shape mismatches.

* style

* Refactor: Move SkyReelsV2Timesteps to model file

Colocates the `SkyReelsV2Timesteps` class with the SkyReelsV2 transformer model.

This change moves model-specific timestep embedding logic from the general embeddings module to the transformer's own file, improving modularity and making the model more self-contained.

* Refactor parameter dtype retrieval to use utility function

Replaces manual parameter iteration with the `get_parameter_dtype` helper to determine the time embedder's data type.

This change improves code readability and centralizes the logic.

* Add comments to track the tensor shape transformations

* Add copied froms

* style

* fix-copies

* up

* Remove FlowMatchUniPCMultistepScheduler

Deletes the `FlowMatchUniPCMultistepScheduler` as it is no longer being used.

* Refactor: Replace FlowMatchUniPC scheduler with UniPC

Removes the `FlowMatchUniPCMultistepScheduler` and integrates its functionality into the existing `UniPCMultistepScheduler`.

This consolidation is achieved by using the `use_flow_sigmas=True` parameter in `UniPCMultistepScheduler`, simplifying the scheduler API and reducing code duplication. All usages, documentation, and tests are updated accordingly.

* style

* Remove text_encoder parameter from SkyReelsV2DiffusionForcingPipeline initialization

* Docs: Rename `pipe` to `pipeline` in SkyReels examples

Updates the variable name from `pipe` to `pipeline` across all SkyReels V2 documentation examples. This change improves clarity and consistency.

* Fix: Rename shift parameter to flow_shift in SkyReels-V2 examples

* Fix: Rename shift parameter to flow_shift in example documentation across SkyReels-V2 files

* Fix: Rename shift parameter to flow_shift in UniPCMultistepScheduler initialization across SkyReels test files

* Removes unused generator argument from scheduler step

The `generator` parameter is not used by the scheduler's `step` method within the SkyReelsV2 diffusion forcing pipelines. This change removes the unnecessary argument from the method call for code clarity and consistency.

* Fix: Update time_embedder_dtype assignment to use the first parameter's dtype in SkyReelsV2TimeTextImageEmbedding

* style

* Refactor: Use get_parameter_dtype utility function

Replaces manual parameter iteration with the `get_parameter_dtype` helper.

* Fix: Prevent (potential) error in parameter dtype check

Adds a check to ensure the `_keep_in_fp32_modules` attribute exists on a parameter before it is accessed.

This prevents a potential `AttributeError`, making the utility function more robust when used with models that do not define this attribute.

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
2025-07-16 08:24:41 -10:00
Sayak Paul 9c13f86579 [training] add an offload utility that can be used as a context manager. (#11775)
* add an offload utility that can be used as a context manager.

* update

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-07-16 09:09:13 +01:00
G.O.D 5c5209720e enable flux pipeline compatible with unipc and dpm-solver (#11908)
* Update pipeline_flux.py

have flux pipeline work with unipc/dpm schedulers

* clean code

* Update scheduling_dpmsolver_multistep.py

* Update scheduling_unipc_multistep.py

* Update pipeline_flux.py

* Update scheduling_deis_multistep.py

* Update scheduling_dpmsolver_singlestep.py

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-07-15 17:49:57 -10:00
Álvaro Somoza aa14f090f8 [ControlnetUnion] Propagate #11888 to img2img (#11929)
img2img fixes
2025-07-15 21:41:35 -04:00
Guoqing Zhu c5d6e0b537 Fixed bug: Uncontrolled recursive calls that caused an infinite loop when loading certain pipelines containing Transformer2DModel (#11923)
* fix a bug about loop call

* fix a bug about loop call

* ruff format

---------

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-07-15 14:58:37 -10:00
lostdisc 39831599f1 Remove forced float64 from onnx stable diffusion pipelines (#11054)
* Update pipeline_onnx_stable_diffusion.py to remove float64

init_noise_sigma was being set as float64 before multiplying with latents, which changed latents into float64 too, which caused errors with onnxruntime since the latter wanted float16.

* Update pipeline_onnx_stable_diffusion_inpaint.py to remove float64

init_noise_sigma was being set as float64 before multiplying with latents, which changed latents into float64 too, which caused errors with onnxruntime since the latter wanted float16.

* Update pipeline_onnx_stable_diffusion_upscale.py to remove float64

init_noise_sigma was being set as float64 before multiplying with latents, which changed latents into float64 too, which caused errors with onnxruntime since the latter wanted float16.

* Update pipeline_onnx_stable_diffusion.py with comment for previous commit

Added comment on purpose of init_noise_sigma.  This comment exists in related scripts that use the same line of code, but it was missing here.

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-07-15 14:57:28 -10:00
Aryan b73c738392 Remove device synchronization when loading weights (#11927)
* update

* make style
2025-07-15 21:40:57 +05:30
Aryan 06fd427797 [tests] Improve Flux tests (#11919)
update
2025-07-15 10:47:41 +05:30
dependabot[bot] 48a551251d Bump aiohttp from 3.10.10 to 3.12.14 in /examples/server (#11924)
Bumps [aiohttp](https://github.com/aio-libs/aiohttp) from 3.10.10 to 3.12.14.
- [Release notes](https://github.com/aio-libs/aiohttp/releases)
- [Changelog](https://github.com/aio-libs/aiohttp/blob/master/CHANGES.rst)
- [Commits](https://github.com/aio-libs/aiohttp/compare/v3.10.10...v3.12.14)

---
updated-dependencies:
- dependency-name: aiohttp
  dependency-version: 3.12.14
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-15 09:15:57 +05:30
Hengyue-Bi 6398fbc391 Fix: Align VAE processing in ControlNet SD3 training with inference (#11909)
Fix: Apply vae_shift_factor in ControlNet SD3 training
2025-07-14 14:54:38 -04:00
Colle 3c8b67b371 Flux: pass joint_attention_kwargs when using gradient_checkpointing (#11814)
Flux: pass joint_attention_kwargs when gradient_checkpointing
2025-07-11 08:35:18 -10:00
Steven Liu 9feb946432 [docs] torch.compile blog post (#11837)
* add blog post

* feedback

* feedback
2025-07-11 10:29:40 -07:00
Aryan c90352754a Speedup model loading by 4-5x (#11904)
* update

* update

* update

* pin accelerate version

* add comment explanations

* update docstring

* make style

* non_blocking does not matter for dtype cast

* _empty_cache -> clear_cache

* update

* Update src/diffusers/models/model_loading_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/diffusers/models/model_loading_utils.py

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2025-07-11 21:43:53 +05:30
Sayak Paul 7a935a0bbe [tests] Unify compilation + offloading tests in quantization (#11910)
* unify the quant compile + offloading tests.

* fix

* update
2025-07-11 17:02:29 +05:30
chenxiao 941b7fc084 Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) (#11763)
* Avoid creating tensor in CosmosAttnProcessor2_0 (#11761)

* up

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
2025-07-10 11:51:05 -10:00
Álvaro Somoza 76a62ac9cc [ControlnetUnion] Multiple Fixes (#11888)
fixes

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-07-10 14:35:28 -04:00
Sayak Paul 1c6ab9e900 [utils] account for MPS when available in get_device(). (#11905)
* account for MPS when available in get_device().

* fix
2025-07-10 13:30:54 +05:30
Sayak Paul 265840a098 [LoRA] fix: disabling hooks when loading loras. (#11896)
fix: disabling hooks when loading loras.
2025-07-10 10:30:10 +05:30
dependabot[bot] 9f4d997d8f Bump torch from 2.4.1 to 2.7.0 in /examples/server (#11429)
Bumps [torch](https://github.com/pytorch/pytorch) from 2.4.1 to 2.7.0.
- [Release notes](https://github.com/pytorch/pytorch/releases)
- [Changelog](https://github.com/pytorch/pytorch/blob/main/RELEASE.md)
- [Commits](https://github.com/pytorch/pytorch/compare/v2.4.1...v2.7.0)

---
updated-dependencies:
- dependency-name: torch
  dependency-version: 2.7.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-07-10 09:24:10 +05:30
Sayak Paul b41abb2230 [quant] QoL improvements for pipeline-level quant config (#11876)
* add repr for pipelinequantconfig.

* update
2025-07-10 08:53:01 +05:30
YiYi Xu f33b89bafb The Modular Diffusers (#9672)
adding modular diffusers as experimental feature 

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-07-09 16:00:28 -10:00
Álvaro Somoza 48a6d29550 [SD3] CFG Cutoff fix and official callback (#11890)
fix and official callback

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-07-09 14:31:11 -04:00
Sayak Paul 2d3d376bc0 Fix unique memory address when doing group-offloading with disk (#11767)
* fix memory address problem

* add more tests

* updates

* updates

* update

* _group_id = group_id

* update

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* update

* update

* update

* fix

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-07-09 21:29:34 +05:30
Sébastien Iooss db715e2c8c feat: add multiple input image support in Flux Kontext (#11880)
* feat: add multiple input image support in Flux Kontext

* move model to community

* fix linter
2025-07-09 11:09:59 -04:00
Sayak Paul 754fe85cac [tests] add compile + offload tests for GGUF. (#11740)
* add compile + offload tests for GGUF.

* quality

* add init.

* prop.

* change to flux.

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-07-09 13:42:13 +05:30
Sayak Paul cc1f9a2ce3 [tests] mark the wanvace lora tester flaky (#11883)
* mark wanvace lora tests as flaky

* ability to apply is_flaky at a class-level

* update

* increase max_attempt.

* increase attemtp.
2025-07-09 13:27:15 +05:30
Sayak Paul 737d7fc3b0 [tests] Remove more deprecated tests (#11895)
* remove k diffusion tests

* remove script
2025-07-09 13:10:44 +05:30
Sayak Paul be23f7df00 [Docker] update doc builder dockerfile to include quant libs. (#11728)
update doc builder dockerfile to include quant libs.
2025-07-09 12:27:22 +05:30
Sayak Paul 86becea77f Pin k-diffusion for CI (#11894)
* remove k-diffusion as we don't use it from the core.

* Revert "remove k-diffusion as we don't use it from the core."

This reverts commit 8bc86925a0.

* pin k-diffusion
2025-07-09 12:17:45 +05:30
Dhruv Nair 7e3bf4aff6 [CI] Speed up GPU PR Tests (#11887)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-07-09 11:00:23 +05:30
shm4r7 de043c6044 Update chroma.md (#11891)
Fix typo in Inference example code
2025-07-09 09:58:38 +05:30
Sayak Paul 4c20624cc6 [tests] annotate compilation test classes with bnb (#11715)
annotate compilation test classes with bnb
2025-07-09 09:24:52 +05:30
182 changed files with 31132 additions and 2374 deletions
+2 -1
View File
@@ -13,6 +13,7 @@ on:
- "src/diffusers/loaders/peft.py"
- "tests/pipelines/test_pipelines_common.py"
- "tests/models/test_modeling_common.py"
- "examples/**/*.py"
workflow_dispatch:
concurrency:
@@ -188,7 +189,7 @@ jobs:
shell: bash
strategy:
fail-fast: false
max-parallel: 2
max-parallel: 4
matrix:
module: [models, schedulers, lora, others]
steps:
+5 -1
View File
@@ -47,6 +47,10 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
tensorboard \
transformers \
matplotlib \
setuptools==69.5.1
setuptools==69.5.1 \
bitsandbytes \
torchao \
gguf \
optimum-quanto
CMD ["/bin/bash"]
+207 -172
View File
@@ -1,36 +1,39 @@
- sections:
- title: Get started
sections:
- local: index
title: 🧨 Diffusers
title: Diffusers
- local: installation
title: Installation
- local: quicktour
title: Quicktour
- local: stable_diffusion
title: Effective and efficient diffusion
- local: installation
title: Installation
title: Get started
- sections:
- local: tutorials/tutorial_overview
title: Overview
- local: using-diffusers/write_own_pipeline
title: Understanding pipelines, models and schedulers
- local: tutorials/autopipeline
title: AutoPipeline
- local: tutorials/basic_training
title: Train a diffusion model
title: Tutorials
- sections:
- title: DiffusionPipeline
isExpanded: false
sections:
- local: using-diffusers/loading
title: Load pipelines
- local: tutorials/autopipeline
title: AutoPipeline
- local: using-diffusers/custom_pipeline_overview
title: Load community pipelines and components
- local: using-diffusers/callback
title: Pipeline callbacks
- local: using-diffusers/reusing_seeds
title: Reproducible pipelines
- local: using-diffusers/schedulers
title: Load schedulers and models
- local: using-diffusers/scheduler_features
title: Scheduler features
- local: using-diffusers/other-formats
title: Model files and layouts
- local: using-diffusers/push_to_hub
title: Push files to the Hub
title: Load pipelines and adapters
- sections:
- title: Adapters
isExpanded: false
sections:
- local: tutorials/using_peft_for_inference
title: LoRA
- local: using-diffusers/ip_adapter
@@ -43,25 +46,12 @@
title: DreamBooth
- local: using-diffusers/textual_inversion_inference
title: Textual inversion
title: Adapters
- title: Inference
isExpanded: false
- sections:
- local: using-diffusers/unconditional_image_generation
title: Unconditional image generation
- local: using-diffusers/conditional_image_generation
title: Text-to-image
- local: using-diffusers/img2img
title: Image-to-image
- local: using-diffusers/inpaint
title: Inpainting
- local: using-diffusers/text-img2vid
title: Video generation
- local: using-diffusers/depth2img
title: Depth-to-image
title: Generative tasks
- sections:
- local: using-diffusers/overview_techniques
title: Overview
sections:
- local: using-diffusers/weighted_prompts
title: Prompt techniques
- local: using-diffusers/create_a_server
title: Create a server
- local: using-diffusers/batched_inference
@@ -76,14 +66,38 @@
title: Reproducible pipelines
- local: using-diffusers/image_quality
title: Controlling image quality
- local: using-diffusers/weighted_prompts
title: Prompt techniques
title: Inference techniques
- sections:
- local: advanced_inference/outpaint
title: Outpainting
title: Advanced inference
- sections:
- title: Inference optimization
isExpanded: false
sections:
- local: optimization/fp16
title: Accelerate inference
- local: optimization/cache
title: Caching
- local: optimization/memory
title: Reduce memory usage
- local: optimization/speed-memory-optims
title: Compile and offloading quantized models
- title: Community optimizations
sections:
- local: optimization/pruna
title: Pruna
- local: optimization/xformers
title: xFormers
- local: optimization/tome
title: Token merging
- local: optimization/deepcache
title: DeepCache
- local: optimization/tgate
title: TGATE
- local: optimization/xdit
title: xDiT
- local: optimization/para_attn
title: ParaAttention
- title: Hybrid Inference
isExpanded: false
sections:
- local: hybrid_inference/overview
title: Overview
- local: hybrid_inference/vae_decode
@@ -92,8 +106,110 @@
title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
title: Hybrid Inference
- sections:
- title: Modular Diffusers
isExpanded: false
sections:
- local: modular_diffusers/overview
title: Overview
- local: modular_diffusers/modular_pipeline
title: Modular Pipeline
- local: modular_diffusers/components_manager
title: Components Manager
- local: modular_diffusers/modular_diffusers_states
title: Modular Diffusers States
- local: modular_diffusers/pipeline_block
title: Pipeline Block
- local: modular_diffusers/sequential_pipeline_blocks
title: Sequential Pipeline Blocks
- local: modular_diffusers/loop_sequential_pipeline_blocks
title: Loop Sequential Pipeline Blocks
- local: modular_diffusers/auto_pipeline_blocks
title: Auto Pipeline Blocks
- local: modular_diffusers/end_to_end_guide
title: End-to-End Example
- title: Training
isExpanded: false
sections:
- local: training/overview
title: Overview
- local: training/create_dataset
title: Create a dataset for training
- local: training/adapt_a_model
title: Adapt a model to a new task
- local: tutorials/basic_training
title: Train a diffusion model
- title: Models
sections:
- local: training/unconditional_training
title: Unconditional image generation
- local: training/text2image
title: Text-to-image
- local: training/sdxl
title: Stable Diffusion XL
- local: training/kandinsky
title: Kandinsky 2.2
- local: training/wuerstchen
title: Wuerstchen
- local: training/controlnet
title: ControlNet
- local: training/t2i_adapters
title: T2I-Adapters
- local: training/instructpix2pix
title: InstructPix2Pix
- local: training/cogvideox
title: CogVideoX
- title: Methods
sections:
- local: training/text_inversion
title: Textual Inversion
- local: training/dreambooth
title: DreamBooth
- local: training/lora
title: LoRA
- local: training/custom_diffusion
title: Custom Diffusion
- local: training/lcm_distill
title: Latent Consistency Distillation
- local: training/ddpo
title: Reinforcement learning training with DDPO
- title: Quantization
isExpanded: false
sections:
- local: quantization/overview
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
- local: quantization/gguf
title: gguf
- local: quantization/torchao
title: torchao
- local: quantization/quanto
title: quanto
- title: Model accelerators and hardware
isExpanded: false
sections:
- local: using-diffusers/stable_diffusion_jax_how_to
title: JAX/Flax
- local: optimization/onnx
title: ONNX
- local: optimization/open_vino
title: OpenVINO
- local: optimization/coreml
title: Core ML
- local: optimization/mps
title: Metal Performance Shaders (MPS)
- local: optimization/habana
title: Intel Gaudi
- local: optimization/neuron
title: AWS Neuron
- title: Specific pipeline examples
isExpanded: false
sections:
- local: using-diffusers/consisid
title: ConsisID
- local: using-diffusers/sdxl
@@ -118,106 +234,30 @@
title: Stable Video Diffusion
- local: using-diffusers/marigold_usage
title: Marigold Computer Vision
title: Specific pipeline examples
- sections:
- local: training/overview
title: Overview
- local: training/create_dataset
title: Create a dataset for training
- local: training/adapt_a_model
title: Adapt a model to a new task
- isExpanded: false
- title: Resources
isExpanded: false
sections:
- title: Task recipes
sections:
- local: training/unconditional_training
- local: using-diffusers/unconditional_image_generation
title: Unconditional image generation
- local: training/text2image
- local: using-diffusers/conditional_image_generation
title: Text-to-image
- local: training/sdxl
title: Stable Diffusion XL
- local: training/kandinsky
title: Kandinsky 2.2
- local: training/wuerstchen
title: Wuerstchen
- local: training/controlnet
title: ControlNet
- local: training/t2i_adapters
title: T2I-Adapters
- local: training/instructpix2pix
title: InstructPix2Pix
- local: training/cogvideox
title: CogVideoX
title: Models
- isExpanded: false
sections:
- local: training/text_inversion
title: Textual Inversion
- local: training/dreambooth
title: DreamBooth
- local: training/lora
title: LoRA
- local: training/custom_diffusion
title: Custom Diffusion
- local: training/lcm_distill
title: Latent Consistency Distillation
- local: training/ddpo
title: Reinforcement learning training with DDPO
title: Methods
title: Training
- sections:
- local: quantization/overview
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
- local: quantization/gguf
title: gguf
- local: quantization/torchao
title: torchao
- local: quantization/quanto
title: quanto
title: Quantization Methods
- sections:
- local: optimization/fp16
title: Accelerate inference
- local: optimization/cache
title: Caching
- local: optimization/memory
title: Reduce memory usage
- local: optimization/speed-memory-optims
title: Compile and offloading quantized models
- local: optimization/pruna
title: Pruna
- local: optimization/xformers
title: xFormers
- local: optimization/tome
title: Token merging
- local: optimization/deepcache
title: DeepCache
- local: optimization/tgate
title: TGATE
- local: optimization/xdit
title: xDiT
- local: optimization/para_attn
title: ParaAttention
- sections:
- local: using-diffusers/stable_diffusion_jax_how_to
title: JAX/Flax
- local: optimization/onnx
title: ONNX
- local: optimization/open_vino
title: OpenVINO
- local: optimization/coreml
title: Core ML
title: Optimized model formats
- sections:
- local: optimization/mps
title: Metal Performance Shaders (MPS)
- local: optimization/habana
title: Intel Gaudi
- local: optimization/neuron
title: AWS Neuron
title: Optimized hardware
title: Accelerate inference and reduce memory
- sections:
- local: using-diffusers/img2img
title: Image-to-image
- local: using-diffusers/inpaint
title: Inpainting
- local: advanced_inference/outpaint
title: Outpainting
- local: using-diffusers/text-img2vid
title: Video generation
- local: using-diffusers/depth2img
title: Depth-to-image
- local: using-diffusers/write_own_pipeline
title: Understanding pipelines, models and schedulers
- local: community_projects
title: Projects built with Diffusers
- local: conceptual/philosophy
title: Philosophy
- local: using-diffusers/controlling_generation
@@ -228,13 +268,11 @@
title: Diffusers' Ethical Guidelines
- local: conceptual/evaluation
title: Evaluating Diffusion Models
title: Conceptual Guides
- sections:
- local: community_projects
title: Projects built with Diffusers
title: Community Projects
- sections:
- isExpanded: false
- title: API
isExpanded: false
sections:
- title: Main Classes
sections:
- local: api/configuration
title: Configuration
@@ -244,8 +282,7 @@
title: Outputs
- local: api/quantization
title: Quantization
title: Main Classes
- isExpanded: false
- title: Loaders
sections:
- local: api/loaders/ip_adapter
title: IP-Adapter
@@ -261,14 +298,14 @@
title: SD3Transformer2D
- local: api/loaders/peft
title: PEFT
title: Loaders
- isExpanded: false
- title: Models
sections:
- local: api/models/overview
title: Overview
- local: api/models/auto_model
title: AutoModel
- sections:
- title: ControlNets
sections:
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_union
@@ -283,8 +320,8 @@
title: SD3ControlNetModel
- local: api/models/controlnet_sparsectrl
title: SparseControlNetModel
title: ControlNets
- sections:
- title: Transformers
sections:
- local: api/models/allegro_transformer3d
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
@@ -333,6 +370,8 @@
title: SanaTransformer2DModel
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/skyreels_v2_transformer_3d
title: SkyReelsV2Transformer3DModel
- local: api/models/stable_audio_transformer
title: StableAudioDiTModel
- local: api/models/transformer2d
@@ -341,8 +380,8 @@
title: TransformerTemporalModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
title: Transformers
- sections:
- title: UNets
sections:
- local: api/models/stable_cascade_unet
title: StableCascadeUNet
- local: api/models/unet
@@ -357,8 +396,8 @@
title: UNetMotionModel
- local: api/models/uvit2d
title: UViT2DModel
title: UNets
- sections:
- title: VAEs
sections:
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
@@ -389,9 +428,7 @@
title: Tiny AutoEncoder
- local: api/models/vq
title: VQModel
title: VAEs
title: Models
- isExpanded: false
- title: Pipelines
sections:
- local: api/pipelines/overview
title: Overview
@@ -527,11 +564,14 @@
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
- local: api/pipelines/skyreels_v2
title: SkyReels-V2
- local: api/pipelines/stable_audio
title: Stable Audio
- local: api/pipelines/stable_cascade
title: Stable Cascade
- sections:
- title: Stable Diffusion
sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview
- local: api/pipelines/stable_diffusion/depth2img
@@ -568,7 +608,6 @@
title: T2I-Adapter
- local: api/pipelines/stable_diffusion/text2img
title: Text-to-image
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
- local: api/pipelines/text_to_video
@@ -587,8 +626,7 @@
title: Wan
- local: api/pipelines/wuerstchen
title: Wuerstchen
title: Pipelines
- isExpanded: false
- title: Schedulers
sections:
- local: api/schedulers/overview
title: Overview
@@ -658,8 +696,7 @@
title: UniPCMultistepScheduler
- local: api/schedulers/vq_diffusion
title: VQDiffusionScheduler
title: Schedulers
- isExpanded: false
- title: Internal classes
sections:
- local: api/internal_classes_overview
title: Overview
@@ -677,5 +714,3 @@
title: VAE Image Processor
- local: api/video_processor
title: Video Processor
title: Internal classes
title: API
+7 -2
View File
@@ -26,6 +26,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
- [`SkyReelsV2LoraLoaderMixin`] provides similar functions for [SkyReels-V2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/skyreels_v2).
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
@@ -92,6 +93,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
## SkyReelsV2LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.SkyReelsV2LoraLoaderMixin
## AmusedLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
@@ -100,6 +105,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
## WanLoraLoaderMixin
## LoraBaseMixin
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
[[autodoc]] loaders.lora_base.LoraBaseMixin
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# SkyReelsV2Transformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [SkyReels-V2](https://github.com/SkyworkAI/SkyReels-V2) by the Skywork AI.
The model can be loaded with the following code snippet.
```python
from diffusers import SkyReelsV2Transformer3DModel
transformer = SkyReelsV2Transformer3DModel.from_pretrained("Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## SkyReelsV2Transformer3DModel
[[autodoc]] SkyReelsV2Transformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
+1 -1
View File
@@ -36,7 +36,7 @@ import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.enabe_model_cpu_offload()
pipe.enable_model_cpu_offload()
prompt = [
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
+367
View File
@@ -0,0 +1,367 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</a>
</div>
</div>
# SkyReels-V2: Infinite-length Film Generative model
[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team.
*Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation. To address these limitations, we propose SkyReels-V2, an Infinite-length Film Generative Model, that synergizes Multi-modal Large Language Model (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing Framework. Firstly, we design a comprehensive structural representation of video that combines the general descriptions by the Multi-modal LLM and the detailed shot language by sub-expert models. Aided with human annotation, we then train a unified Video Captioner, named SkyCaptioner-V1, to efficiently label the video data. Secondly, we establish progressive-resolution pretraining for the fundamental video generation, followed by a four-stage post-training enhancement: Initial concept-balanced Supervised Fine-Tuning (SFT) improves baseline quality; Motion-specific Reinforcement Learning (RL) training with human-annotated and synthetic distortion data addresses dynamic artifacts; Our diffusion forcing framework with non-decreasing noise schedules enables long-video synthesis in an efficient search space; Final high-quality SFT refines visual fidelity. All the code and models are available at [this https URL](https://github.com/SkyworkAI/SkyReels-V2).*
You can find all the original SkyReels-V2 checkpoints under the [Skywork](https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9) organization.
The following SkyReels-V2 models are supported in Diffusers:
- [SkyReels-V2 DF 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers)
- [SkyReels-V2 DF 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P-Diffusers)
- [SkyReels-V2 DF 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P-Diffusers)
- [SkyReels-V2 T2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-540P-Diffusers)
- [SkyReels-V2 T2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-720P-Diffusers)
- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)
> [!TIP]
> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.
### A _Visual_ Demonstration
An example with these parameters:
base_num_frames=97, num_frames=97, num_inference_steps=30, ar_step=5, causal_block_size=5
vae_scale_factor_temporal -> 4
num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each
base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 → blocks = 25//5 = 5 blocks
This 5 blocks means the maximum context length of the model is 25 frames in the latent space.
Asynchronous Processing Timeline:
┌─────────────────────────────────────────────────────────────────┐
│ Steps: 1 6 11 16 21 26 31 36 41 46 50 │
│ Block 1: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
│ Block 2: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
│ Block 3: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
│ Block 4: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
│ Block 5: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
└─────────────────────────────────────────────────────────────────┘
For Long Videos (num_frames > base_num_frames):
base_num_frames acts as the "sliding window size" for processing long videos.
Example: 257-frame video with base_num_frames=97, overlap_history=17
┌──── Iteration 1 (frames 1-97) ────┐
│ Processing window: 97 frames │ → 5 blocks, async processing
│ Generates: frames 1-97 │
└───────────────────────────────────┘
┌────── Iteration 2 (frames 81-177) ──────┐
│ Processing window: 97 frames │
│ Overlap: 17 frames (81-97) from prev │ → 5 blocks, async processing
│ Generates: frames 98-177 │
└─────────────────────────────────────────┘
┌────── Iteration 3 (frames 161-257) ──────┐
│ Processing window: 97 frames │
│ Overlap: 17 frames (161-177) from prev │ → 5 blocks, async processing
│ Generates: frames 178-257 │
└──────────────────────────────────────────┘
Each iteration independently runs the asynchronous processing with its own 5 blocks.
base_num_frames controls:
1. Memory usage (larger window = more VRAM)
2. Model context length (must match training constraints)
3. Number of blocks per iteration (base_num_latent_frames // causal_block_size)
Each block takes 30 steps to complete denoising.
Block N starts at step: 1 + (N-1) x ar_step
Total steps: 30 + (5-1) x 5 = 50 steps
Synchronous mode (ar_step=0) would process all blocks/frames simultaneously:
┌──────────────────────────────────────────────┐
│ Steps: 1 ... 30 │
│ All blocks: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
└──────────────────────────────────────────────┘
Total steps: 30 steps
An example on how the step matrix is constructed for asynchronous processing:
Given the parameters: (num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5)
- num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25
- step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,
941, 932, 922, 912, 901, 888, 874, 859, 841, 822,
799, 773, 743, 708, 666, 615, 551, 470, 363, 216]
The algorithm creates a 50x25 step_matrix where:
- Row 1: [999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- Row 2: [995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- Row 3: [991, 991, 991, 991, 991, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- ...
- Row 7: [969, 969, 969, 969, 969, 995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- ...
- Row 21: [799, 799, 799, 799, 799, 888, 888, 888, 888, 888, 941, 941, 941, 941, 941, 975, 975, 975, 975, 975, 999, 999, 999, 999, 999]
- ...
- Row 35: [ 0, 0, 0, 0, 0, 216, 216, 216, 216, 216, 666, 666, 666, 666, 666, 822, 822, 822, 822, 822, 901, 901, 901, 901, 901]
- ...
- Row 42: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 551, 551, 551, 551, 551, 773, 773, 773, 773, 773]
- ...
- Row 50: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 216, 216, 216, 216, 216]
Detailed Row 6 Analysis:
- step_matrix[5]: [ 975, 975, 975, 975, 975, 999, 999, 999, 999, 999, 999, ..., 999]
- step_index[5]: [ 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 0, ..., 0]
- step_update_mask[5]: [True,True,True,True,True,True,True,True,True,True,False, ...,False]
- valid_interval[5]: (0, 25)
Key Pattern: Block i lags behind Block i-1 by exactly ar_step=5 timesteps, creating the
staggered "diffusion forcing" effect where later blocks condition on cleaner earlier blocks.
### Text-to-Video Generation
The example below demonstrates how to generate a video from text.
<hfoptions id="T2V usage">
<hfoption id="T2V memory">
Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
From the original repo:
>You can use --ar_step 5 to enable asynchronous inference. When asynchronous inference, --causal_block_size 5 is recommended while it is not supposed to be set for synchronous generation... Asynchronous inference will take more steps to diffuse the whole sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance.
```py
# pip install ftfy
import torch
from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32)
transformer = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
"Skywork/SkyReels-V2-DF-14B-540P-Diffusers",
vae=vae,
transformer=transformer,
torch_dtype=torch.bfloat16
)
flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline = pipeline.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
output = pipeline(
prompt=prompt,
num_inference_steps=30,
height=544, # 720 for 720P
width=960, # 1280 for 720P
num_frames=97,
base_num_frames=97, # 121 for 720P
ar_step=5, # Controls asynchronous inference (0 for synchronous mode)
causal_block_size=5, # Number of frames in each block for asynchronous processing
overlap_history=None, # Number of frames to overlap for smooth transitions in long videos; 17 for long video generations
addnoise_condition=20, # Improves consistency in long video generation
).frames[0]
export_to_video(output, "T2V.mp4", fps=24, quality=8)
```
</hfoption>
</hfoptions>
### First-Last-Frame-to-Video Generation
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.
<hfoptions id="FLF2V usage">
<hfoption id="usage">
```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline.to("cuda")
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
def aspect_ratio_resize(image, pipeline, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
def center_crop_resize(image, height, width):
# Calculate resize ratio to match first frame dimensions
resize_ratio = max(width / image.width, height / image.height)
# Resize the image
width = round(image.width * resize_ratio)
height = round(image.height * resize_ratio)
size = [width, height]
image = TF.center_crop(image, size)
return image, height, width
first_frame, height, width = aspect_ratio_resize(first_frame, pipeline)
if last_frame.size != first_frame.size:
last_frame, _, _ = center_crop_resize(last_frame, height, width)
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipeline(
image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.0
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=8)
```
</hfoption>
</hfoptions>
### Video-to-Video Generation
<hfoptions id="V2V usage">
<hfoption id="usage">
`SkyReelsV2DiffusionForcingVideoToVideoPipeline` extends a given video.
```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_video
model_id = "Skywork/SkyReels-V2-DF-14B-540P-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline.to("cuda")
video = load_video("input_video.mp4")
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipeline(
video=video, prompt=prompt, height=544, width=960, guidance_scale=5.0,
num_inference_steps=30, num_frames=257, base_num_frames=97#, ar_step=5, causal_block_size=5,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=8)
# Total frames will be the number of frames of given video + 257
```
</hfoption>
</hfoptions>
## Notes
- SkyReels-V2 supports LoRAs with [`~loaders.SkyReelsV2LoraLoaderMixin.load_lora_weights`].
<details>
<summary>Show example code</summary>
```py
# pip install ftfy
import torch
from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained(
"Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
"Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", vae=vae, torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
pipeline.load_lora_weights("benjamin-paine/steamboat-willie-1.3b", adapter_name="steamboat-willie")
pipeline.set_adapters("steamboat-willie")
pipeline.enable_model_cpu_offload()
# use "steamboat willie style" to trigger the LoRA
prompt = """
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
output = pipeline(
prompt=prompt,
num_frames=97,
guidance_scale=6.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24)
```
</details>
## SkyReelsV2DiffusionForcingPipeline
[[autodoc]] SkyReelsV2DiffusionForcingPipeline
- all
- __call__
## SkyReelsV2DiffusionForcingImageToVideoPipeline
[[autodoc]] SkyReelsV2DiffusionForcingImageToVideoPipeline
- all
- __call__
## SkyReelsV2DiffusionForcingVideoToVideoPipeline
[[autodoc]] SkyReelsV2DiffusionForcingVideoToVideoPipeline
- all
- __call__
## SkyReelsV2Pipeline
[[autodoc]] SkyReelsV2Pipeline
- all
- __call__
## SkyReelsV2ImageToVideoPipeline
[[autodoc]] SkyReelsV2ImageToVideoPipeline
- all
- __call__
## SkyReelsV2PipelineOutput
[[autodoc]] pipelines.skyreels_v2.pipeline_output.SkyReelsV2PipelineOutput
@@ -0,0 +1,316 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoPipelineBlocks
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
`AutoPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that automatically selects which sub-blocks to run based on the inputs provided at runtime, creating conditional workflows that adapt to different scenarios. The main purpose is convenience and portability - for developers, you can package everything into one workflow, making it easier to share and use.
In this tutorial, we will show you how to create an `AutoPipelineBlocks` and learn more about how the conditional selection works.
<Tip>
Other types of multi-blocks include [SequentialPipelineBlocks](sequential_pipeline_blocks.md) (for linear workflows) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md).
Additionally, like all `ModularPipelineBlocks`, `AutoPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
</Tip>
For example, you might want to support text-to-image and image-to-image tasks. Instead of creating two separate pipelines, you can create an `AutoPipelineBlocks` that automatically chooses the workflow based on whether an `image` input is provided.
Let's see an example. We'll use the helper function from the [PipelineBlock guide](./pipeline_block.md) to create our blocks:
**Helper Function**
```py
from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
import torch
def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
class TestBlock(PipelineBlock):
model_name = "test"
@property
def inputs(self):
return inputs
@property
def intermediate_inputs(self):
return intermediate_inputs
@property
def intermediate_outputs(self):
return intermediate_outputs
@property
def description(self):
return description if description is not None else ""
def __call__(self, components, state):
block_state = self.get_block_state(state)
if block_fn is not None:
block_state = block_fn(block_state, state)
self.set_block_state(state, block_state)
return components, state
return TestBlock
```
Now let's create a dummy `AutoPipelineBlocks` that includes dummy text-to-image, image-to-image, and inpaint pipelines.
```py
from diffusers.modular_pipelines import AutoPipelineBlocks
# These are dummy blocks and we only focus on "inputs" for our purpose
inputs = [InputParam(name="prompt")]
# block_fn prints out which workflow is running so we can see the execution order at runtime
block_fn = lambda x, y: print("running the text-to-image workflow")
block_t2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a text-to-image workflow!")
inputs = [InputParam(name="prompt"), InputParam(name="image")]
block_fn = lambda x, y: print("running the image-to-image workflow")
block_i2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a image-to-image workflow!")
inputs = [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
block_fn = lambda x, y: print("running the inpaint workflow")
block_inpaint_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a inpaint workflow!")
class AutoImageBlocks(AutoPipelineBlocks):
# List of sub-block classes to choose from
block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
# Names for each block in the same order
block_names = ["inpaint", "img2img", "text2img"]
# Trigger inputs that determine which block to run
# - "mask" triggers inpaint workflow
# - "image" triggers img2img workflow (but only if mask is not provided)
# - if none of above, runs the text2img workflow (default)
block_trigger_inputs = ["mask", "image", None]
# Description is extremely important for AutoPipelineBlocks
@property
def description(self):
return (
"Pipeline generates images given different types of conditions!\n"
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
+ " - inpaint workflow is run when `mask` is provided.\n"
+ " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
+ " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
)
# Create the blocks
auto_blocks = AutoImageBlocks()
# convert to pipeline
auto_pipeline = auto_blocks.init_pipeline()
```
Now we have created an `AutoPipelineBlocks` that contains 3 sub-blocks. Notice the warning message at the top - this automatically appears in every `ModularPipelineBlocks` that contains `AutoPipelineBlocks` to remind end users that dynamic block selection happens at runtime.
```py
AutoImageBlocks(
Class: AutoPipelineBlocks
====================================================================================================
This pipeline contains blocks that are selected at runtime based on inputs.
Trigger Inputs: ['mask', 'image']
====================================================================================================
Description: Pipeline generates images given different types of conditions!
This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
- inpaint workflow is run when `mask` is provided.
- img2img workflow is run when `image` is provided (but only when `mask` is not provided).
- text2img workflow is run when neither `image` nor `mask` is provided.
Sub-Blocks:
inpaint [trigger: mask] (TestBlock)
Description: I'm a inpaint workflow!
img2img [trigger: image] (TestBlock)
Description: I'm a image-to-image workflow!
text2img [default] (TestBlock)
Description: I'm a text-to-image workflow!
)
```
Check out the documentation with `print(auto_pipeline.doc)`:
```py
>>> print(auto_pipeline.doc)
class AutoImageBlocks
Pipeline generates images given different types of conditions!
This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
- inpaint workflow is run when `mask` is provided.
- img2img workflow is run when `image` is provided (but only when `mask` is not provided).
- text2img workflow is run when neither `image` nor `mask` is provided.
Inputs:
prompt (`None`, *optional*):
image (`None`, *optional*):
mask (`None`, *optional*):
```
There is a fundamental trade-off of AutoPipelineBlocks: it trades clarity for convenience. While it is really easy for packaging multiple workflows, it can become confusing without proper documentation. e.g. if we just throw a pipeline at you and tell you that it contains 3 sub-blocks and takes 3 inputs `prompt`, `image` and `mask`, and ask you to run an image-to-image workflow: if you don't have any prior knowledge on how these pipelines work, you would be pretty clueless, right?
This pipeline we just made though, has a docstring that shows all available inputs and workflows and explains how to use each with different inputs. So it's really helpful for users. For example, it's clear that you need to pass `image` to run img2img. This is why the description field is absolutely critical for AutoPipelineBlocks. We highly recommend you to explain the conditional logic very well for each `AutoPipelineBlocks` you would make. We also recommend to always test individual pipelines first before packaging them into AutoPipelineBlocks.
Let's run this auto pipeline with different inputs to see if the conditional logic works as described. Remember that we have added `print` in each `PipelineBlock`'s `__call__` method to print out its workflow name, so it should be easy to tell which one is running:
```py
>>> _ = auto_pipeline(image="image", mask="mask")
running the inpaint workflow
>>> _ = auto_pipeline(image="image")
running the image-to-image workflow
>>> _ = auto_pipeline(prompt="prompt")
running the text-to-image workflow
>>> _ = auto_pipeline(image="prompt", mask="mask")
running the inpaint workflow
```
However, even with documentation, it can become very confusing when AutoPipelineBlocks are combined with other blocks. The complexity grows quickly when you have nested AutoPipelineBlocks or use them as sub-blocks in larger pipelines.
Let's make another `AutoPipelineBlocks` - this one only contains one block, and it does not include `None` in its `block_trigger_inputs` (which corresponds to the default block to run when none of the trigger inputs are provided). This means this block will be skipped if the trigger input (`ip_adapter_image`) is not provided at runtime.
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
inputs = [InputParam(name="ip_adapter_image")]
block_fn = lambda x, y: print("running the ip-adapter workflow")
block_ipa_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a IP-adapter workflow!")
class AutoIPAdapter(AutoPipelineBlocks):
block_classes = [block_ipa_cls]
block_names = ["ip-adapter"]
block_trigger_inputs = ["ip_adapter_image"]
@property
def description(self):
return "Run IP Adapter step if `ip_adapter_image` is provided."
```
Now let's combine these 2 auto blocks together into a `SequentialPipelineBlocks`:
```py
auto_ipa_blocks = AutoIPAdapter()
blocks_dict = InsertableDict()
blocks_dict["ip-adapter"] = auto_ipa_blocks
blocks_dict["image-generation"] = auto_blocks
all_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
pipeline = all_blocks.init_pipeline()
```
Let's take a look: now things get more confusing. In this particular example, you could still try to explain the conditional logic in the `description` field here - there are only 4 possible execution paths so it's doable. However, since this is a `SequentialPipelineBlocks` that could contain many more blocks, the complexity can quickly get out of hand as the number of blocks increases.
```py
>>> all_blocks
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
====================================================================================================
This pipeline contains blocks that are selected at runtime based on inputs.
Trigger Inputs: ['image', 'mask', 'ip_adapter_image']
Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('image')`).
====================================================================================================
Description:
Sub-Blocks:
[0] ip-adapter (AutoIPAdapter)
Description: Run IP Adapter step if `ip_adapter_image` is provided.
[1] image-generation (AutoImageBlocks)
Description: Pipeline generates images given different types of conditions!
This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
- inpaint workflow is run when `mask` is provided.
- img2img workflow is run when `image` is provided (but only when `mask` is not provided).
- text2img workflow is run when neither `image` nor `mask` is provided.
)
```
This is when the `get_execution_blocks()` method comes in handy - it basically extracts a `SequentialPipelineBlocks` that only contains the blocks that are actually run based on your inputs.
Let's try some examples:
`mask`: we expect it to skip the first ip-adapter since `ip_adapter_image` is not provided, and then run the inpaint for the second block.
```py
>>> all_blocks.get_execution_blocks('mask')
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
Description:
Sub-Blocks:
[0] image-generation (TestBlock)
Description: I'm a inpaint workflow!
)
```
Let's also actually run the pipeline to confirm:
```py
>>> _ = pipeline(mask="mask")
skipping auto block: AutoIPAdapter
running the inpaint workflow
```
Try a few more:
```py
print(f"inputs: ip_adapter_image:")
blocks_select = all_blocks.get_execution_blocks('ip_adapter_image')
print(f"expected_execution_blocks: {blocks_select}")
print(f"actual execution blocks:")
_ = pipeline(ip_adapter_image="ip_adapter_image", prompt="prompt")
# expect to see ip-adapter + text2img
print(f"inputs: image:")
blocks_select = all_blocks.get_execution_blocks('image')
print(f"expected_execution_blocks: {blocks_select}")
print(f"actual execution blocks:")
_ = pipeline(image="image", prompt="prompt")
# expect to see img2img
print(f"inputs: prompt:")
blocks_select = all_blocks.get_execution_blocks('prompt')
print(f"expected_execution_blocks: {blocks_select}")
print(f"actual execution blocks:")
_ = pipeline(prompt="prompt")
# expect to see text2img (prompt is not a trigger input so fallback to default)
print(f"inputs: mask + ip_adapter_image:")
blocks_select = all_blocks.get_execution_blocks('mask','ip_adapter_image')
print(f"expected_execution_blocks: {blocks_select}")
print(f"actual execution blocks:")
_ = pipeline(mask="mask", ip_adapter_image="ip_adapter_image")
# expect to see ip-adapter + inpaint
```
In summary, `AutoPipelineBlocks` is a good tool for packaging multiple workflows into a single, convenient interface and it can greatly simplify the user experience. However, always provide clear descriptions explaining the conditional logic, test individual pipelines first before combining them, and use `get_execution_blocks()` to understand runtime behavior in complex compositions.
@@ -0,0 +1,514 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Components Manager
<Tip warning={true}>
🧪 **Experimental Feature**: This is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
The Components Manager is a central model registry and management system in diffusers. It lets you add models then reuse them across multiple pipelines and workflows. It tracks all models in one place with useful metadata such as model size, device placement and loaded adapters (LoRA, IP-Adapter). It has mechanisms in place to prevent duplicate model instances, enables memory-efficient sharing. Most significantly, it offers offloading that works across pipelines — unlike regular DiffusionPipeline offloading (i.e. `enable_model_cpu_offload` and `enable_sequential_cpu_offload`) which is limited to one pipeline with predefined sequences, the Components Manager automatically manages your device memory across all your models and workflows.
## Basic Operations
Let's start with the most basic operations. First, create a Components Manager:
```py
from diffusers import ComponentsManager
comp = ComponentsManager()
```
Use the `add(name, component)` method to register a component. It returns a unique ID that combines the component name with the object's unique identifier (using Python's `id()` function):
```py
from diffusers import AutoModel
text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
# Returns component_id like 'text_encoder_139917733042864'
component_id = comp.add("text_encoder", text_encoder)
```
You can view all registered components and their metadata:
```py
>>> comp
Components:
===============================================================================================================================================
Models:
-----------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
-----------------------------------------------------------------------------------------------------------------------------------------------
text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
-----------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
And remove components using their unique ID:
```py
comp.remove("text_encoder_139917733042864")
```
## Duplicate Detection
The Components Manager automatically detects and prevents duplicate model instances to save memory and avoid confusion. Let's walk through how this works in practice.
When you try to add the same object twice, the manager will warn you and return the existing ID:
```py
>>> comp.add("text_encoder", text_encoder)
'text_encoder_139917733042864'
>>> comp.add("text_encoder", text_encoder)
ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917733042864'
'text_encoder_139917733042864'
```
Even if you add the same object under a different name, it will still be detected as a duplicate:
```py
>>> comp.add("clip", text_encoder)
ComponentsManager: adding component 'clip' as 'clip_139917733042864', but it is duplicate of 'text_encoder_139917733042864'
To remove a duplicate, call `components_manager.remove('<component_id>')`.
'clip_139917733042864'
```
However, there's a more subtle case where duplicate detection becomes tricky. When you load the same model into different objects, the manager can't detect duplicates unless you use `ComponentSpec`. For example:
```py
>>> text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
>>> comp.add("text_encoder", text_encoder_2)
'text_encoder_139917732983664'
```
This creates a problem - you now have two copies of the same model consuming double the memory:
```py
>>> comp
Components:
===============================================================================================================================================
Models:
-----------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
-----------------------------------------------------------------------------------------------------------------------------------------------
text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
clip_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
text_encoder_139917732983664 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
-----------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
We recommend using `ComponentSpec` to load your models. Models loaded with `ComponentSpec` get tagged with a unique ID that encodes their loading parameters, allowing the Components Manager to detect when different objects represent the same underlying checkpoint:
```py
from diffusers import ComponentSpec, ComponentsManager
from transformers import CLIPTextModel
comp = ComponentsManager()
# Create ComponentSpec for the first text encoder
spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from same repo/subfolder)
spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel)
# Load and add both components - the manager will detect they're the same model
comp.add("text_encoder", spec.load())
comp.add("text_encoder_duplicated", spec_duplicated.load())
```
Now the manager detects the duplicate and warns you:
```out
ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('<component_id>')`.
'text_encoder_duplicated_139917580682672'
```
Both models now show the same `load_id`, making it clear they're the same model:
```py
>>> comp
Components:
======================================================================================================================================================================================================
Models:
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
text_encoder_139918506246832 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A
text_encoder_duplicated_139917580682672 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
## Collections
Collections are labels you can assign to components for better organization and management. You add a component under a collection by passing the `collection=` parameter when you add the component to the manager, i.e. `add(name, component, collection=...)`. Within each collection, only one component per name is allowed - if you add a second component with the same name, the first one is automatically removed.
Here's how collections work in practice:
```py
comp = ComponentsManager()
# Create ComponentSpec for the first UNet (SDXL base)
spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
# Create ComponentSpec for a different UNet (Juggernaut-XL)
spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")
# Add both UNets to the same collection - the second one will replace the first
comp.add("unet", spec.load(), collection="sdxl")
comp.add("unet", spec2.load(), collection="sdxl")
```
The manager automatically removes the old UNet and adds the new one:
```out
ComponentsManager: removing existing unet from collection 'sdxl': unet_139917723891888
'unet_139917723893136'
```
Only one UNet remains in the collection:
```py
>>> comp
Components:
====================================================================================================================================================================
Models:
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
unet_139917723893136 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | sdxl
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
For example, in node-based systems, you can mark all models loaded from one node with the same collection label, automatically replace models when user loads new checkpoints under same name, batch delete all models in a collection when a node is removed.
## Retrieving Components
The Components Manager provides several methods to retrieve registered components.
The `get_one()` method returns a single component and supports pattern matching for the `name` parameter. You can use:
- exact matches like `comp.get_one(name="unet")`
- wildcards like `comp.get_one(name="unet*")` for components starting with "unet"
- exclusion patterns like `comp.get_one(name="!unet")` to exclude components named "unet"
- OR patterns like `comp.get_one(name="unet|vae")` to match either "unet" OR "vae".
Optionally, You can add collection and load_id as filters e.g. `comp.get_one(name="unet", collection="sdxl")`. If multiple components match, `get_one()` throws an error.
Another useful method is `get_components_by_names()`, which takes a list of names and returns a dictionary mapping names to components. This is particularly helpful with modular pipelines since they provide lists of required component names, and the returned dictionary can be directly passed to `pipeline.update_components()`.
```py
# Get components by name list
component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
# Returns: {"text_encoder": component1, "unet": component2, "vae": component3}
```
## Using Components Manager with Modular Pipelines
The Components Manager integrates seamlessly with Modular Pipelines. All you need to do is pass a Components Manager instance to `from_pretrained()` or `init_pipeline()` with an optional `collection` parameter:
```py
from diffusers import ModularPipeline, ComponentsManager
comp = ComponentsManager()
pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
```
By default, modular pipelines don't load components immediately, so both the pipeline and Components Manager start empty:
```py
>>> comp
Components:
==================================================
No components registered.
==================================================
```
When you load components on the pipeline, they are automatically registered in the Components Manager:
```py
>>> pipe.load_components(names="unet")
>>> comp
Components:
==============================================================================================================================================================
Models:
--------------------------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
--------------------------------------------------------------------------------------------------------------------------------------------------------------
unet_139917726686304 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1
--------------------------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
Now let's load all default components and then create a second pipeline that reuses all components from the first one. We pass the same Components Manager to the second pipeline but with a different collection:
```py
# Load all default components
>>> pipe.load_default_components()
# Create a second pipeline using the same Components Manager but with a different collection
>>> pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
```
As mentioned earlier, `ModularPipeline` has a property `null_component_names` that returns a list of component names it needs to load. We can conveniently use this list with the `get_components_by_names` method on the Components Manager:
```py
# Get the list of components that pipe2 needs to load
>>> pipe2.null_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
# Retrieve all required components from the Components Manager
>>> comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
# Update the pipeline with the retrieved components
>>> pipe2.update_components(**comp_dict)
```
The warnings that follow are expected and indicate that the Components Manager is correctly identifying that these components already exist and will be reused rather than creating duplicates:
```out
ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917586016400'
ComponentsManager: component 'text_encoder_2' already exists as 'text_encoder_2_139917699973424'
ComponentsManager: component 'tokenizer' already exists as 'tokenizer_139917580599504'
ComponentsManager: component 'tokenizer_2' already exists as 'tokenizer_2_139915763443904'
ComponentsManager: component 'image_encoder' already exists as 'image_encoder_139917722468304'
ComponentsManager: component 'unet' already exists as 'unet_139917580609632'
ComponentsManager: component 'vae' already exists as 'vae_139917722459040'
ComponentsManager: component 'scheduler' already exists as 'scheduler_139916266559408'
ComponentsManager: component 'controlnet' already exists as 'controlnet_139917722454432'
```
The pipeline is now fully loaded:
```py
# null_component_names return empty list, meaning everything are loaded
>>> pipe2.null_component_names
[]
```
No new components were added to the Components Manager - we're reusing everything. All models are now associated with both `test1` and `test2` collections, showing that these components are shared across multiple pipelines:
```py
>>> comp
Components:
========================================================================================================================================================================================
Models:
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
text_encoder_139917586016400 | CLIPTextModel | cpu | torch.float32 | 0.46 | SG161222/RealVisXL_V4.0|text_encoder|null|null | test1
| | | | | | test2
text_encoder_2_139917699973424 | CLIPTextModelWithProjection | cpu | torch.float32 | 2.59 | SG161222/RealVisXL_V4.0|text_encoder_2|null|null | test1
| | | | | | test2
unet_139917580609632 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1
| | | | | | test2
controlnet_139917722454432 | ControlNetModel | cpu | torch.float32 | 4.66 | diffusers/controlnet-canny-sdxl-1.0|null|null|null | test1
| | | | | | test2
vae_139917722459040 | AutoencoderKL | cpu | torch.float32 | 0.31 | SG161222/RealVisXL_V4.0|vae|null|null | test1
| | | | | | test2
image_encoder_139917722468304 | CLIPVisionModelWithProjection | cpu | torch.float32 | 6.87 | h94/IP-Adapter|sdxl_models/image_encoder|null|null | test1
| | | | | | test2
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Other Components:
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
ID | Class | Collection
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
tokenizer_139917580599504 | CLIPTokenizer | test1
| | test2
scheduler_139916266559408 | EulerDiscreteScheduler | test1
| | test2
tokenizer_2_139915763443904 | CLIPTokenizer | test1
| | test2
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
## Automatic Memory Management
The Components Manager provides a global offloading strategy across all models, regardless of which pipeline is using them:
```py
comp.enable_auto_cpu_offload(device="cuda")
```
When enabled, all models start on CPU. The manager moves models to the device right before they're used and moves other models back to CPU when GPU memory runs low. You can set your own rules for which models to offload first. This works smoothly as you add or remove components. Once it's on, you don't need to worry about device placement - you can focus on your workflow.
## Practical Example: Building Modular Workflows with Component Reuse
Now that we've covered the basics of the Components Manager, let's walk through a practical example that shows how to build workflows in a modular setting and use the Components Manager to reuse components across multiple pipelines. This example demonstrates the true power of Modular Diffusers by working with multiple pipelines that can share components.
In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline.
Let's create a modular text-to-image workflow by separating it into three workflows: `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images.
```py
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
# Create modular blocks and separate text encoding and decoding steps
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2img"])
text_blocks = t2i_blocks.sub_blocks.pop("text_encoder")
decoder_blocks = t2i_blocks.sub_blocks.pop("decode")
```
Now we will convert them into runnalbe pipelines and set up the Components Manager with auto offloading and organize components under a "t2i" collection
Since we now have 3 different workflows that share components, we create a separate pipeline that serves as a dedicated loader to load all the components, register them to the component manager, and then reuse them across different workflows.
```py
from diffusers import ComponentsManager, ModularPipeline
# Set up Components Manager with auto offloading
components = ComponentsManager()
components.enable_auto_cpu_offload(device="cuda")
# Create a new pipeline to load the components
t2i_repo = "YiYiXu/modular-demo-auto"
t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i")
# convert the 3 blocks into pipelines and attach the same components manager to all 3
text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components)
decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components)
t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components)
```
Load all components into the loader pipeline, they should all be automatically registered to Components Manager under the "t2i" collection:
```py
# Load all components (including IP-Adapter and ControlNet for later use)
t2i_loader_pipe.load_default_components(torch_dtype=torch.float16)
```
Now distribute the loaded components to each pipeline:
```py
# Get VAE for decoder (using get_one since there's only one)
vae = components.get_one(load_id="SG161222/RealVisXL_V4.0|vae|null|null")
decoder_node.update_components(vae=vae)
# Get text components for text node (using get_components_by_names for multiple components)
text_components = components.get_components_by_names(text_node.null_component_names)
text_node.update_components(**text_components)
# Get remaining components for t2i pipeline
t2i_components = components.get_components_by_names(t2i_pipe.null_component_names)
t2i_pipe.update_components(**t2i_components)
```
Now we can generate images using our modular workflow:
```py
# Generate text embeddings
prompt = "an astronaut"
text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"])
# Generate latents and decode to image
generator = torch.Generator(device="cuda").manual_seed(0)
latents_t2i = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
image = decoder_node(latents=latents_t2i, output="images")[0]
image.save("modular_part2_t2i.png")
```
Let's add a LoRA:
```py
# Load LoRA weights
>>> t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face")
>>> components
Components:
============================================================================================================================================================
...
Additional Component Info:
==================================================
unet:
Adapters: ['toy_face']
```
You can see that the Components Manager tracks adapters metadata for all models it manages, and in our case, only Unet has lora loaded. This means we can reuse existing text embeddings.
```py
# Generate with LoRA (reusing existing text embeddings)
generator = torch.Generator(device="cuda").manual_seed(0)
latents_lora = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
image = decoder_node(latents=latents_lora, output="images")[0]
image.save("modular_part2_lora.png")
```
Now let's create a refiner pipeline that reuses components from our text-to-image workflow:
```py
# Create refiner blocks (removing image_encoder and decode since we work with latents)
refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"])
refiner_blocks.sub_blocks.pop("image_encoder")
refiner_blocks.sub_blocks.pop("decode")
# Create refiner pipeline with different repo and collection,
# Attach the same component manager to it
refiner_repo = "YiYiXu/modular_refiner"
refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner")
```
We pass the **same Components Manager** (`components`) to the refiner pipeline, but with a **different collection** (`"refiner"`). This allows the refiner to access and reuse components from the "t2i" collection while organizing its own components (like the refiner UNet) under the "refiner" collection.
```py
# Load only the refiner UNet (different from t2i UNet)
refiner_pipe.load_components(names="unet", torch_dtype=torch.float16)
# Reuse components from t2i pipeline using pattern matching
reuse_components = components.search_components("text_encoder_2|scheduler|vae|tokenizer_2")
refiner_pipe.update_components(**reuse_components)
```
When we reuse components from the "t2i" collection, they automatically get added to the "refiner" collection as well. You can verify this by checking the Components Manager - you'll see components like `vae`, `scheduler`, etc. listed under both collections, indicating they're shared between workflows.
Now we can refine any of our generated latents:
```py
# Refine all our different latents
refined_latents = refiner_pipe(image_latents=latents_t2i, prompt=prompt, num_inference_steps=10, output="latents")
refined_image = decoder_node(latents=refined_latents, output="images")[0]
refined_image.save("modular_part2_t2i_refine_out.png")
refined_latents = refiner_pipe(image_latents=latents_lora, prompt=prompt, num_inference_steps=10, output="latents")
refined_image = decoder_node(latents=refined_latents, output="images")[0]
refined_image.save("modular_part2_lora_refine_out.png")
```
Here are the results from our modular pipeline examples.
#### Base Text-to-Image Generation
| Base Text-to-Image | Base Text-to-Image (Refined) |
|-------------------|------------------------------|
| ![Base T2I](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i.png) | ![Base T2I Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i_refine_out.png) |
#### LoRA
| LoRA | LoRA (Refined) |
|-------------------|------------------------------|
| ![LoRA](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora.png) | ![LoRA Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora_refine_out.png) |
@@ -0,0 +1,648 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# End-to-End Developer Guide: Building with Modular Diffusers
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
In this tutorial we will walk through the process of adding a new pipeline to the modular framework using differential diffusion as our example. We'll cover the complete workflow from implementation to deployment: implementing the new pipeline, ensuring compatibility with existing tools, sharing the code on Hugging Face Hub, and deploying it as a UI node.
We'll also demonstrate the 4-step framework process we use for implementing new basic pipelines in the modular system.
1. **Start with an existing pipeline as a base**
- Identify which existing pipeline is most similar to the one you want to implement
- Determine what part of the pipeline needs modification
2. **Build a working pipeline structure first**
- Assemble the complete pipeline structure
- Use existing blocks wherever possible
- For new blocks, create placeholders (e.g. you can copy from similar blocks and change the name) without implementing custom logic just yet
3. **Set up an example**
- Create a simple inference script with expected inputs/outputs
4. **Implement your custom logic and test incrementally**
- Add the custom logics the blocks you want to change
- Test incrementally, and inspect pipeline states and debug as needed
Let's see how this works with the Differential Diffusion example.
## Differential Diffusion Pipeline
### Start with an existing pipeline
Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them.
```py
>>> from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
>>> IMAGE2IMAGE_BLOCKS = InsertableDict([
... ("text_encoder", StableDiffusionXLTextEncoderStep),
... ("image_encoder", StableDiffusionXLVaeEncoderStep),
... ("input", StableDiffusionXLInputStep),
... ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
... ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
... ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
... ("denoise", StableDiffusionXLDenoiseStep),
... ("decode", StableDiffusionXLDecodeStep)
... ])
```
Note that "denoise" (`StableDiffusionXLDenoiseStep`) is a `LoopSequentialPipelineBlocks` that contains 3 loop blocks (more on LoopSequentialPipelineBlocks [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#loopsequentialpipelineblocks))
```py
>>> denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
>>> print(denoise_blocks)
```
```out
StableDiffusionXLDenoiseStep(
Class: StableDiffusionXLDenoiseLoopWrapper
Description: Denoise step that iteratively denoise the latents.
Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
At each iteration, it runs blocks defined in `sub_blocks` sequencially:
- `StableDiffusionXLLoopBeforeDenoiser`
- `StableDiffusionXLLoopDenoiser`
- `StableDiffusionXLLoopAfterDenoiser`
This block supports both text2img and img2img tasks.
Components:
scheduler (`EulerDiscreteScheduler`)
guider (`ClassifierFreeGuidance`)
unet (`UNet2DConditionModel`)
Sub-Blocks:
[0] before_denoiser (StableDiffusionXLLoopBeforeDenoiser)
Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
[1] denoiser (StableDiffusionXLLoopDenoiser)
Description: Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
[2] after_denoiser (StableDiffusionXLLoopAfterDenoiser)
Description: step within the denoising loop that update the latents. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
)
```
Let's compare standard image-to-image and differential diffusion! The key difference in algorithm is that standard image-to-image diffusion applies uniform noise across all pixels based on a single `strength` parameter, but differential diffusion uses a change map where each pixel value determines when that region starts denoising. Regions with lower values get "frozen" earlier by replacing them with noised original latents, preserving more of the original image.
Therefore, the key differences when it comes to pipeline implementation would be:
1. The `prepare_latents` step (which prepares the change map and pre-computes noised latents for all timesteps)
2. The `denoise` step (which selectively applies denoising based on the change map)
3. Since differential diffusion doesn't use the `strength` parameter, we'll use the text-to-image `set_timesteps` step instead of the image-to-image version
To implement differntial diffusion, we can reuse most blocks from image-to-image and text-to-image workflows, only modifying the `prepare_latents` step and the first part of the `denoise` step (i.e. `before_denoiser (StableDiffusionXLLoopBeforeDenoiser)`).
Here's a flowchart showing the pipeline structure and the changes we need to make:
![DiffDiff Pipeline Structure](https://mermaid.ink/img/pako:eNqVVO9r4kAQ_VeWLQWFKEk00eRDwZpa7Q-ucPfpYpE1mdWlcTdsVmpb-7_fZk1tTCl3J0Sy8968N5kZ9g0nIgUc4pUk-Rr9iuYc6d_Ibs14vlXoQYpNrtqo07lAo1jBTi2AlynysWIa6DJmG7KCBnZpsHHMSqkqNjaxKC5ALRTbQKEgLyosMthVnEvIiYRFRhRwVaBoNpmUT0W7MrTJkUbSdJEInlbwxMDXcQpcsAKq6OH_2mDTODIY4yt0J0ReUaYGnLXiJVChdSsB-enfPhBnhnjT-rCQj-1K_8Ygt62YUAVy8Ykf4FvU6XYu9rpuIGqPpvXSzs_RVEj2KrgiGUp02zNQTHBEM_FcK3BfQbBHd7qAst-PxvW-9WOrypnNylG0G9oRUMYBFeolg-IQTTJSFDqOUkZp-fwsQURZloVnlPpLf2kVSoonCM-SwCUuqY6dZ5aqddjLd1YiMiFLNrWorrxj9EOmP4El37lsl_9p5PzFqIqwVwgdN981fDM94bphH5I06R8NXZ_4QcPQPTFs6JltPrS6JssFhw9N817l27bdyM-lSKAo6iVBAAnQY0n9wLO9wbcluY7ruUFDtdguH74K0yENKDkK-8nAG6TfNrfy_bf-HjdrlOfZS7VYSAlU5JAwyhLE9WrWVw1dWdPTXauDsy8LUkdHtnX_pfMnBOvSGluRNbGurbuTHtdZN9Zts1MljC19_7EUh0puwcIbkBtSHvFbic6xWsMG5jjUrymRT3M85-86Jyf8txCbjzQptqs1DinJCn3a5qm-viJG9M26OUYlcH0_jsWWKxwGttHA4Rve4dD1el3H8_yh49hD3_X7roVfcNhx-l3b14PxvGHQ0xMa9t4t_Gp8na7tDvu-4w08HXecweD9D4X54ZI)
### Build a Working Pipeline Structure
ok now we've identified the blocks to modify, let's build the pipeline skeleton first - at this stage, our goal is to get the pipeline struture working end-to-end (even though it's just doing the img2img behavior). I would simply create placeholder blocks by copying from existing ones:
```py
>>> # Copy existing blocks as placeholders
>>> class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
... """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
... # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
...
>>> class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
... """Copied from StableDiffusionXLLoopBeforeDenoiser - will modify later"""
... # ... same implementation as StableDiffusionXLLoopBeforeDenoiser
```
`SDXLDiffDiffLoopBeforeDenoiser` is the be part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseStep`.
```py
>>> class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
... block_names = ["before_denoiser", "denoiser", "after_denoiser"]
```
Now we can put together our differential diffusion pipeline.
```py
>>> DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
>>> DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
>>> DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
>>> DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
>>>
>>> dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
>>> print(dd_blocks)
>>> # At this point, the pipeline works exactly like img2img since our blocks are just copies
```
### Set up an example
ok, so now our blocks should be able to compile without an error, we can move on to the next step. Let's setup a simple example so we can run the pipeline as we build it. diff-diff use same model checkpoints as SDXL so we can fetch the models from a regular SDXL repo.
```py
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
>>> dd_pipeline.to("cuda")
```
We will use this example script:
```py
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
>>>
>>> prompt = "a green pear"
>>> negative_prompt = "blurry"
>>>
>>> image = dd_pipeline(
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_inference_steps=25,
... diffdiff_map=mask,
... image=image,
... output="images"
... )[0]
>>>
>>> image.save("diffdiff_out.png")
```
If you run the script right now, you will get a complaint about unexpected input `diffdiff_map`.
and you would get the same result as the original img2img pipeline.
### implement your custom logic and test incrementally
Let's modify the pipeline so that we can get expected result with this example script.
We'll start with the `prepare_latents` step. The main changes are:
- Requires a new user input `diffdiff_map`
- Requires new component `mask_processor` to process the `diffdiff_map`
- Requires new intermediate inputs:
- Need `timestep` instead of `latent_timestep` to precompute all the latents
- Need `num_inference_steps` to create the `diffdiff_masks`
- create a new output `diffdiff_masks` and `original_latents`
<Tip>
💡 use `print(dd_pipeline.doc)` to check compiled inputs and outputs of the built piepline.
e.g. after we added `diffdiff_map` as an input in this step, we can run `print(dd_pipeline.doc)` to verify that it shows up in the docstring as a user input.
</Tip>
Once we make sure all the variables we need are available in the block state, we can implement the diff-diff logic inside `__call__`. We created 2 new variables: the change map `diffdiff_mask` and the pre-computed noised latents for all timesteps `original_latents`.
<Tip>
💡 Implement incrementally! Run the example script as you go, and insert `print(state)` and `print(block_state)` everywhere inside the `__call__` method to inspect the intermediate results. This helps you understand what's going on and what each line you just added does.
</Tip>
Here are the key changes we made to implement differential diffusion:
**1. Modified `prepare_latents` step:**
```diff
class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
]
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
+ InputParam("diffdiff_map", required=True),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
- InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
+ InputParam("timesteps", type_hint=torch.Tensor),
+ InputParam("num_inference_steps", type_hint=int),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
+ OutputParam("original_latents", type_hint=torch.Tensor),
+ OutputParam("diffdiff_masks", type_hint=torch.Tensor),
]
def __call__(self, components, state: PipelineState):
# ... existing logic ...
+ # Process change map and create masks
+ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
+ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
+ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
+ block_state.original_latents = block_state.latents
```
**2. Modified `before_denoiser` step:**
```diff
class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
@property
def description(self) -> str:
return (
"Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
)
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("denoising_start"),
+ ]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor),
+ InputParam("original_latents", type_hint=torch.Tensor),
+ InputParam("diffdiff_masks", type_hint=torch.Tensor),
]
def __call__(self, components, block_state, i, t):
+ # Apply differential diffusion logic
+ if i == 0 and block_state.denoising_start is None:
+ block_state.latents = block_state.original_latents[:1]
+ else:
+ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
+ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
# ... rest of existing logic ...
```
That's all there is to it! We've just created a simple sequential pipeline by mix-and-match some existing and new pipeline blocks.
Now we use the process we've prepred in step2 to build the pipeline and inspect it.
```py
>> dd_pipeline
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
Description:
Components:
text_encoder (`CLIPTextModel`)
text_encoder_2 (`CLIPTextModelWithProjection`)
tokenizer (`CLIPTokenizer`)
tokenizer_2 (`CLIPTokenizer`)
guider (`ClassifierFreeGuidance`)
vae (`AutoencoderKL`)
image_processor (`VaeImageProcessor`)
scheduler (`EulerDiscreteScheduler`)
mask_processor (`VaeImageProcessor`)
unet (`UNet2DConditionModel`)
Configs:
force_zeros_for_empty_prompt (default: True)
requires_aesthetics_score (default: False)
Blocks:
[0] text_encoder (StableDiffusionXLTextEncoderStep)
Description: Text Encoder step that generate text_embeddings to guide the image generation
[1] image_encoder (StableDiffusionXLVaeEncoderStep)
Description: Vae Encoder step that encode the input image into a latent representation
[2] input (StableDiffusionXLInputStep)
Description: Input processing step that:
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
All input tensors are expected to have either batch_size=1 or match the batch_size
of prompt_embeds. The tensors will be duplicated across the batch dimension to
have a final batch_size of batch_size * num_images_per_prompt.
[3] set_timesteps (StableDiffusionXLSetTimestepsStep)
Description: Step that sets the scheduler's timesteps for inference
[4] prepare_latents (SDXLDiffDiffPrepareLatentsStep)
Description: Step that prepares the latents for the differential diffusion generation process
[5] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process
[6] denoise (SDXLDiffDiffDenoiseStep)
Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes
[7] decode (StableDiffusionXLDecodeStep)
Description: Step that decodes the denoised latents into images
)
```
Run the example now, you should see an apple with its right half transformed into a green pear.
![Image description](https://cdn-uploads.huggingface.co/production/uploads/624ef9ba9d608e459387b34e/4zqJOz-35Q0i6jyUW3liL.png)
## Adding IP-adapter
We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](./auto_pipeline_blocks.md)
We talked about how to add IP-adapter into your workflow in the [Modular Pipeline Guide](./modular_pipeline.md). Let's just go ahead to create the IP-adapter block.
```py
>>> from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
```
We can directly add the ip-adapter block instance to the `diffdiff_blocks` that we created before. The `sub_blocks` attribute is a `InsertableDict`, so we're able to insert the it at specific position (index `0` here).
```py
>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
```
Take a look at the new diff-diff pipeline with ip-adapter!
```py
>>> print(dd_blocks)
```
The pipeline now lists ip-adapter as its first block, and tells you that it will run only if `ip_adapter_image` is provided. It also includes the two new components from ip-adpater: `image_encoder` and `feature_extractor`
```out
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
====================================================================================================
This pipeline contains blocks that are selected at runtime based on inputs.
Trigger Inputs: {'ip_adapter_image'}
Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`).
====================================================================================================
Description:
Components:
image_encoder (`CLIPVisionModelWithProjection`)
feature_extractor (`CLIPImageProcessor`)
unet (`UNet2DConditionModel`)
guider (`ClassifierFreeGuidance`)
text_encoder (`CLIPTextModel`)
text_encoder_2 (`CLIPTextModelWithProjection`)
tokenizer (`CLIPTokenizer`)
tokenizer_2 (`CLIPTokenizer`)
vae (`AutoencoderKL`)
image_processor (`VaeImageProcessor`)
scheduler (`EulerDiscreteScheduler`)
mask_processor (`VaeImageProcessor`)
Configs:
force_zeros_for_empty_prompt (default: True)
requires_aesthetics_score (default: False)
Blocks:
[0] ip_adapter (StableDiffusionXLAutoIPAdapterStep)
Description: Run IP Adapter step if `ip_adapter_image` is provided.
[1] text_encoder (StableDiffusionXLTextEncoderStep)
Description: Text Encoder step that generate text_embeddings to guide the image generation
[2] image_encoder (StableDiffusionXLVaeEncoderStep)
Description: Vae Encoder step that encode the input image into a latent representation
[3] input (StableDiffusionXLInputStep)
Description: Input processing step that:
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
All input tensors are expected to have either batch_size=1 or match the batch_size
of prompt_embeds. The tensors will be duplicated across the batch dimension to
have a final batch_size of batch_size * num_images_per_prompt.
[4] set_timesteps (StableDiffusionXLSetTimestepsStep)
Description: Step that sets the scheduler's timesteps for inference
[5] prepare_latents (SDXLDiffDiffPrepareLatentsStep)
Description: Step that prepares the latents for the differential diffusion generation process
[6] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process
[7] denoise (SDXLDiffDiffDenoiseStep)
Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes
[8] decode (StableDiffusionXLDecodeStep)
Description: Step that decodes the denoised latents into images
)
```
Let's test it out. We used an orange image to condition the generation via ip-addapter and we can see a slight orange color and texture in the final output.
```py
>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
>>>
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
>>> dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
>>> dd_pipeline.loader.set_ip_adapter_scale(0.6)
>>> dd_pipeline = dd_pipeline.to(device)
>>>
>>> ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
>>>
>>> prompt = "a green pear"
>>> negative_prompt = "blurry"
>>> generator = torch.Generator(device=device).manual_seed(42)
>>>
>>> image = dd_pipeline(
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_inference_steps=25,
... generator=generator,
... ip_adapter_image=ip_adapter_image,
... diffdiff_map=mask,
... image=image,
... output="images"
... )[0]
```
## Working with ControlNets
What about controlnet? Can differential diffusion work with controlnet? The key differences between a regular pipeline and a ControlNet pipeline are:
1. A ControlNet input step that prepares the control condition
2. Inside the denoising loop, a modified denoiser step where the control image is first processed through ControlNet, then control information is injected into the UNet
From looking at the code workflow: differential diffusion only modifies the "before denoiser" step, while ControlNet operates within the "denoiser" itself. Since they intervene at different points in the pipeline, they should work together without conflicts.
Intuitively, these two techniques are orthogonal and should combine naturally: differential diffusion controls how much the inference process can deviate from the original in each region, while ControlNet controls in what direction that change occurs.
With this understanding, let's assemble the diffdiff-controlnet loop by combining the diffdiff before-denoiser step and controlnet denoiser step.
```py
>>> class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
... block_names = ["before_denoiser", "denoiser", "after_denoiser"]
>>>
>>> controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
>>> # print(controlnet_denoise)
```
We provide a auto controlnet input block that you can directly put into your workflow to proceess the `control_image`: similar to auto ip-adapter block, this step will only run if `control_image` input is passed from user. It work with both controlnet and controlnet union.
```py
>>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
>>> control_input_block = StableDiffusionXLAutoControlNetInputStep()
>>> print(control_input_block)
```
```out
StableDiffusionXLAutoControlNetInputStep(
Class: AutoPipelineBlocks
====================================================================================================
This pipeline contains blocks that are selected at runtime based on inputs.
Trigger Inputs: ['control_image', 'control_mode']
====================================================================================================
Description: Controlnet Input step that prepare the controlnet input.
This is an auto pipeline block that works for both controlnet and controlnet_union.
(it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.
- `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped.
Components:
controlnet (`ControlNetUnionModel`)
control_image_processor (`VaeImageProcessor`)
Sub-Blocks:
• controlnet_union [trigger: control_mode] (StableDiffusionXLControlNetUnionInputStep)
Description: step that prepares inputs for the ControlNetUnion model
• controlnet [trigger: control_image] (StableDiffusionXLControlNetInputStep)
Description: step that prepare inputs for controlnet
)
```
Let's assemble the blocks and run an example using controlnet + differential diffusion. We used a tomato as `control_image`, so you can see that in the output, the right half that transformed into a pear had a tomato-like shape.
```py
>>> dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
>>> dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
>>>
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
>>> dd_pipeline = dd_pipeline.to(device)
>>>
>>> control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
>>>
>>> prompt = "a green pear"
>>> negative_prompt = "blurry"
>>> generator = torch.Generator(device=device).manual_seed(42)
>>>
>>> image = dd_pipeline(
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_inference_steps=25,
... generator=generator,
... control_image=control_image,
... controlnet_conditioning_scale=0.5,
... diffdiff_map=mask,
... image=image,
... output="images"
... )[0]
```
Optionally, We can combine `SDXLDiffDiffControlNetDenoiseStep` and `SDXLDiffDiffDenoiseStep` into a `AutoPipelineBlocks` so that same workflow can work with or without controlnet.
```py
>>> class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
... block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
... block_names = ["controlnet_denoise", "denoise"]
... block_trigger_inputs = ["controlnet_cond", None]
```
`SDXLDiffDiffAutoDenoiseStep` will run the ControlNet denoise step if `control_image` input is provided, otherwise it will run the regular denoise step.
<Tip>
Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end once you've verified all your pipelines work as expected.
</Tip>
Now you can create the differential diffusion preset that works with ip-adapter & controlnet.
```py
>>> DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
>>> DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
>>> DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
>>> DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
>>> DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
>>> DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7)
>>>
>>> print(DIFFDIFF_AUTO_BLOCKS)
```
to use
```py
>>> dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
>>> dd_pipeline = dd_auto_blocks.init_pipeline(...)
```
## Creating a Modular Repo
You can easily share your differential diffusion workflow on the Hub by creating a modular repo. This is one created using the code we just wrote together: https://huggingface.co/YiYiXu/modular-diffdiff
To create a Modular Repo and share on hub, you just need to run `save_pretrained()` along with the `push_to_hub=True` flag. Note that if your pipeline contains custom block, you need to manually upload the code to the hub. But we are working on a command line tool to help you upload it very easily.
```py
dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
```
With a modular repo, it is very easy for the community to use the workflow you just created! Here is an example to use the differential-diffusion pipeline we just created and shared.
```py
>>> from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
>>> import torch
>>> from diffusers.utils import load_image
>>>
>>> repo_id = "YiYiXu/modular-diffdiff-0704"
>>>
>>> components = ComponentsManager()
>>>
>>> diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, components_manager=components, collection="diffdiff")
>>> diffdiff_pipeline.load_default_components(torch_dtype=torch.float16)
>>> components.enable_auto_cpu_offload()
```
see more usage example on model card.
## deploy a mellon node
[YIYI TODO: for now, here is an example of mellon node https://huggingface.co/YiYiXu/diff-diff-mellon]
@@ -0,0 +1,194 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LoopSequentialPipelineBlocks
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
`LoopSequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that composes other blocks together in a loop, creating iterative workflows where blocks run multiple times with evolving state. It's particularly useful for denoising loops requiring repeated execution of the same blocks.
<Tip>
Other types of multi-blocks include [SequentialPipelineBlocks](./sequential_pipeline_blocks.md) (for linear workflows) and [AutoPipelineBlocks](./auto_pipeline_blocks.md) (for conditional block selection). For information on creating individual blocks, see the [PipelineBlock guide](./pipeline_block.md).
Additionally, like all `ModularPipelineBlocks`, `LoopSequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
</Tip>
You could create a loop using `PipelineBlock` like this:
```python
class DenoiseLoop(PipelineBlock):
def __call__(self, components, state):
block_state = self.get_block_state(state)
for t in range(block_state.num_inference_steps):
# ... loop logic here
pass
self.set_block_state(state, block_state)
return components, state
```
But in this tutorial, we will focus on how to use `LoopSequentialPipelineBlocks` to create a "composable" denoising loop where you can add or remove blocks within the loop or reuse the same loop structure with different block combinations.
It involves two parts: a **loop wrapper** and **loop blocks**
* The **loop wrapper** (`LoopSequentialPipelineBlocks`) defines the loop structure, e.g. it defines the iteration variables, and loop configurations such as progress bar.
* The **loop blocks** are basically standard pipeline blocks you add to the loop wrapper.
- they run sequentially for each iteration of the loop
- they receive the current iteration index as an additional parameter
- they share the same block_state throughout the entire loop
Unlike regular `SequentialPipelineBlocks` where each block gets its own state, loop blocks share a single state that persists and evolves across iterations.
We will build a simple loop block to demonstrate these concepts. Creating a loop block involves three steps:
1. defining the loop wrapper class
2. creating the loop blocks
3. adding the loop blocks to the loop wrapper class to create the loop wrapper instance
**Step 1: Define the Loop Wrapper**
To create a `LoopSequentialPipelineBlocks` class, you need to define:
* `loop_inputs`: User input variables (equivalent to `PipelineBlock.inputs`)
* `loop_intermediate_inputs`: Intermediate variables needed from the mutable pipeline state (equivalent to `PipelineBlock.intermediates_inputs`)
* `loop_intermediate_outputs`: New intermediate variables this block will add to the mutable pipeline state (equivalent to `PipelineBlock.intermediates_outputs`)
* `__call__` method: Defines the loop structure and iteration logic
Here is an example of a loop wrapper:
```py
import torch
from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, PipelineBlock, InputParam, OutputParam
class LoopWrapper(LoopSequentialPipelineBlocks):
model_name = "test"
@property
def description(self):
return "I'm a loop!!"
@property
def loop_inputs(self):
return [InputParam(name="num_steps")]
@torch.no_grad()
def __call__(self, components, state):
block_state = self.get_block_state(state)
# Loop structure - can be customized to your needs
for i in range(block_state.num_steps):
# loop_step executes all registered blocks in sequence
components, block_state = self.loop_step(components, block_state, i=i)
self.set_block_state(state, block_state)
return components, state
```
**Step 2: Create Loop Blocks**
Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently:
* It receives the iteration variable (e.g., `i`) passed by the loop wrapper
* It works directly with `block_state` instead of pipeline state
* No need to call `self.get_block_state()` or `self.set_block_state()`
```py
class LoopBlock(PipelineBlock):
# this is used to identify the model family, we won't worry about it in this example
model_name = "test"
@property
def inputs(self):
return [InputParam(name="x")]
@property
def intermediate_outputs(self):
# outputs produced by this block
return [OutputParam(name="x")]
@property
def description(self):
return "I'm a block used inside the `LoopWrapper` class"
def __call__(self, components, block_state, i: int):
block_state.x += 1
return components, block_state
```
**Step 3: Combine Everything**
Finally, assemble your loop by adding the block(s) to the wrapper:
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock})
```
Now you've created a loop with one step:
```py
>>> loop
LoopWrapper(
Class: LoopSequentialPipelineBlocks
Description: I'm a loop!!
Sub-Blocks:
[0] block1 (LoopBlock)
Description: I'm a block used inside the `LoopWrapper` class
)
```
It has two inputs: `x` (used at each step within the loop) and `num_steps` used to define the loop.
```py
>>> print(loop.doc)
class LoopWrapper
I'm a loop!!
Inputs:
x (`None`, *optional*):
num_steps (`None`, *optional*):
Outputs:
x (`None`):
```
**Running the Loop:**
```py
# run the loop
loop_pipeline = loop.init_pipeline()
x = loop_pipeline(num_steps=10, x=0, output="x")
assert x == 10
```
**Adding Multiple Blocks:**
We can add multiple blocks to run within each iteration. Let's run the loop block twice within each iteration:
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
loop_pipeline = loop.init_pipeline()
x = loop_pipeline(num_steps=10, x=0, output="x")
assert x == 20 # Each iteration runs 2 blocks, so 10 iterations * 2 = 20
```
**Key Differences from SequentialPipelineBlocks:**
The main difference is that loop blocks share the same `block_state` across all iterations, allowing values to accumulate and evolve throughout the loop. Loop blocks could receive additional arguments (like the current iteration index) depending on the loop wrapper's implementation, since the wrapper defines how loop blocks are called. You can easily add, remove, or reorder blocks within the loop without changing the loop logic itself.
The officially supported denoising loops in Modular Diffusers are implemented using `LoopSequentialPipelineBlocks`. You can explore the actual implementation to see how these concepts work in practice:
```py
from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLDenoiseStep
StableDiffusionXLDenoiseStep()
```
@@ -0,0 +1,59 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# PipelineState and BlockState
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
In Modular Diffusers, `PipelineState` and `BlockState` are the core data structures that enable blocks to communicate and share data. The concept is fundamental to understand how blocks interact with each other and the pipeline system.
In the modular diffusers system, `PipelineState` acts as the global state container that all pipeline blocks operate on. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data.
A `PipelineState` consists of two distinct states:
- **The immutable state** (i.e. the `inputs` dict) contains a copy of values provided by users. Once a value is added to the immutable state, it cannot be changed. Blocks can read from the immutable state but cannot write to it.
- **The mutable state** (i.e. the `intermediates` dict) contains variables that are passed between blocks and can be modified by them.
Here's an example of what a `PipelineState` looks like:
```py
PipelineState(
inputs={
'prompt': 'a cat'
'guidance_scale': 7.0
'num_inference_steps': 25
},
intermediates={
'prompt_embeds': Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1]))
'negative_prompt_embeds': None
},
)
```
Each pipeline blocks define what parts of that state they can read from and write to through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties. At run time, they gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes.
For example, if a block defines an input `image`, inside the block's `__call__` method, the `BlockState` would contain:
```py
BlockState(
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494640>
)
```
You can access the variables directly as attributes: `block_state.image`.
We will explore more on how blocks interact with pipeline state through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties, see the [PipelineBlock guide](./pipeline_block.md).
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,42 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Getting Started with Modular Diffusers
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you:
**Write Only What's New**: You won't need to write an entire pipeline from scratch every time you have a new use case. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities.
**Assemble Like LEGO®**: You can mix and match between blocks in flexible ways. This allows you to write dedicated blocks unique to specific workflows, and then assemble different blocks into a pipeline that can be used more conveniently for multiple workflows.
Here's how our guides are organized to help you navigate the Modular Diffusers documentation:
### 🚀 Running Pipelines
- **[Modular Pipeline Guide](./modular_pipeline.md)** - How to use predefined blocks to build a pipeline and run it
- **[Components Manager Guide](./components_manager.md)** - How to manage and reuse components across multiple pipelines
### 📚 Creating PipelineBlocks
- **[Pipeline and Block States](./modular_diffusers_states.md)** - Understanding PipelineState and BlockState
- **[Pipeline Block](./pipeline_block.md)** - How to write custom PipelineBlocks
- **[SequentialPipelineBlocks](sequential_pipeline_blocks.md)** - Connecting blocks in sequence
- **[LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks.md)** - Creating iterative workflows
- **[AutoPipelineBlocks](./auto_pipeline_blocks.md)** - Conditional block selection
### 🎯 Practical Examples
- **[End-to-End Example](./end_to_end_guide.md)** - Complete end-to-end examples including sharing your workflow in huggingface hub and deplying UI nodes
@@ -0,0 +1,292 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# PipelineBlock
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows.
<Tip>
**Important**: `PipelineBlock`s are definitions/specifications, not runnable pipelines. They define what a block should do and what data it needs, but you need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](./modular_pipeline.md).
</Tip>
In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with the pipeline state.
## PipelineState
Before we dive into creating `PipelineBlock`s, make sure you have a basic understanding of `PipelineState`. It acts as the global state container that all blocks operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. See the [PipelineState and BlockState guide](./modular_diffusers_states.md) for more details.
## Define a `PipelineBlock`
To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce.
The three main properties you need to define are:
- `inputs`: Immutable values from the user that cannot be modified
- `intermediate_inputs`: Mutable values from previous blocks that can be read and modified
- `intermediate_outputs`: New values your block creates for subsequent blocks and user access
Let's explore each one and understand how they work with the pipeline state.
**Inputs: Immutable User Values**
Inputs are variables your block needs from the immutable pipeline state - these are user-provided values that cannot be modified by any block. You define them using `InputParam`:
```py
user_inputs = [
InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
]
```
When you list something as an input, you're saying "I need this value directly from the end user, and I will talk to them directly, telling them what I need in the 'description' field. They will provide it and it will come to me unchanged."
This is especially useful for raw values that serve as the "source of truth" in your workflow. For example, with a raw image, many workflows require preprocessing steps like resizing that a previous block might have performed. But in many cases, you also want the raw PIL image. In some inpainting workflows, you need the original image to overlay with the generated result for better control and consistency.
**Intermediate Inputs: Mutable Values from Previous Blocks, or Users**
Intermediate inputs are variables your block needs from the mutable pipeline state - these are values that can be read and modified. They're typically created by previous blocks, but could also be directly provided by the user if not the case:
```py
user_intermediate_inputs = [
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
]
```
When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers want use something different."
**Intermediate Outputs: New Values for Subsequent Blocks and User Access**
Intermediate outputs are new variables your block creates and adds to the mutable pipeline state. They serve two purposes:
1. **For subsequent blocks**: They can be used as intermediate inputs by other blocks in the pipeline
2. **For users**: They become available as final outputs that users can access when running the pipeline
```py
user_intermediate_outputs = [
OutputParam(name="image_latents", description="latents representing the image")
]
```
Intermediate inputs and intermediate outputs work together like Lego studs and anti-studs - they're the connection points that make blocks modular. When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. This is where the "modular" nature of the system really shines - blocks can be connected and reconnected in different ways as long as their inputs and outputs match.
Additionally, all intermediate outputs are accessible to users when they run the pipeline, typically you would only need the final images, but they are also able to access intermediate results like latents, embeddings, or other processing steps.
**The `__call__` Method Structure**
Your `PipelineBlock`'s `__call__` method should follow this structure:
```py
def __call__(self, components, state):
# Get a local view of the state variables this block needs
block_state = self.get_block_state(state)
# Your computation logic here
# block_state contains all your inputs and intermediate_inputs
# You can access them like: block_state.image, block_state.processed_image
# Update the pipeline state with your updated block_states
self.set_block_state(state, block_state)
return components, state
```
The `block_state` object contains all the variables you defined in `inputs` and `intermediate_inputs`, making them easily accessible for your computation.
**Components and Configs**
You can define the components and pipeline-level configs your block needs using `ComponentSpec` and `ConfigSpec`:
```py
from diffusers import ComponentSpec, ConfigSpec
# Define components your block needs
expected_components = [
ComponentSpec(name="unet", type_hint=UNet2DConditionModel),
ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler)
]
# Define pipeline-level configs
expected_config = [
ConfigSpec("force_zeros_for_empty_prompt", True)
]
```
**Components**: In the `ComponentSpec`, you must provide a `name` and ideally a `type_hint`. You can also specify a `default_creation_method` to indicate whether the component should be loaded from a pretrained model or created with default configurations. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [Modular Pipeline Guide](./modular_pipeline.md).
**Configs**: Pipeline-level settings that control behavior across all blocks.
When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block as the first argument of the `__call__` method. You can access any component you need using dot notation:
```py
def __call__(self, components, state):
# Access components using dot notation
unet = components.unet
vae = components.vae
scheduler = components.scheduler
```
That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks
Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.set_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that.
**Helper Function**
```py
from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
import torch
def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
class TestBlock(PipelineBlock):
model_name = "test"
@property
def inputs(self):
return inputs
@property
def intermediate_inputs(self):
return intermediate_inputs
@property
def intermediate_outputs(self):
return intermediate_outputs
@property
def description(self):
return description if description is not None else ""
def __call__(self, components, state):
block_state = self.get_block_state(state)
if block_fn is not None:
block_state = block_fn(block_state, state)
self.set_block_state(state, block_state)
return components, state
return TestBlock
```
## Example: Creating a Simple Pipeline Block
Let's create a simple block to see how these definitions interact with the pipeline state. To better understand what's happening, we'll print out the states before and after updates to inspect them:
```py
inputs = [
InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
]
intermediate_inputs = [InputParam(name="batch_size", type_hint=int)]
intermediate_outputs = [
OutputParam(name="image_latents", description="latents representing the image")
]
def image_encoder_block_fn(block_state, pipeline_state):
print(f"pipeline_state (before update): {pipeline_state}")
print(f"block_state (before update): {block_state}")
# Simulate processing the image
block_state.image = torch.randn(1, 3, 512, 512)
block_state.batch_size = block_state.batch_size * 2
block_state.processed_image = [torch.randn(1, 3, 512, 512)] * block_state.batch_size
block_state.image_latents = torch.randn(1, 4, 64, 64)
print(f"block_state (after update): {block_state}")
return block_state
# Create a block with our definitions
image_encoder_block_cls = make_block(
inputs=inputs,
intermediate_inputs=intermediate_inputs,
intermediate_outputs=intermediate_outputs,
block_fn=image_encoder_block_fn,
description="Encode raw image into its latent presentation"
)
image_encoder_block = image_encoder_block_cls()
pipe = image_encoder_block.init_pipeline()
```
Let's check the pipeline's docstring to see what inputs it expects:
```py
>>> print(pipe.doc)
class TestBlock
Encode raw image into its latent presentation
Inputs:
image (`PIL.Image`, *optional*):
raw input image to process
batch_size (`int`, *optional*):
Outputs:
image_latents (`None`):
latents representing the image
```
Notice that `batch_size` appears as an input even though we defined it as an intermediate input. This happens because no previous block provided it, so the pipeline makes it available as a user input. However, unlike regular inputs, this value goes directly into the mutable intermediate state.
Now let's run the pipeline:
```py
from diffusers.utils import load_image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png")
state = pipe(image=image, batch_size=2)
print(f"pipeline_state (after update): {state}")
```
```out
pipeline_state (before update): PipelineState(
inputs={
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494550>
},
intermediates={
batch_size: 2
},
)
block_state (before update): BlockState(
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494640>
batch_size: 2
)
block_state (after update): BlockState(
image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512]))
batch_size: 4
processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])]
image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64]))
)
pipeline_state (after update): PipelineState(
inputs={
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494550>
},
intermediates={
batch_size: 4
image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64]))
},
)
```
**Key Observations:**
1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`.
2. **After the update**:
- **`image` (inputs)** changed in `block_state` but not in `pipeline_state` - this change is local to the block only.
- **`batch_size (intermediate_inputs)`** was updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict)
- **`image_latents (intermediate_outputs)`** was added to `pipeline_state` because it was declared as an intermediate output
- **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output
@@ -0,0 +1,189 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# SequentialPipelineBlocks
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
`SequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. Unlike `PipelineBlock`, it is a multi-block that composes other blocks together in sequence, creating modular workflows where data flows from one block to the next. It's one of the most common ways to build complex pipelines by combining simpler building blocks.
<Tip>
Other types of multi-blocks include [AutoPipelineBlocks](auto_pipeline_blocks.md) (for conditional block selection) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md).
Additionally, like all `ModularPipelineBlocks`, `SequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
</Tip>
In this tutorial, we will focus on how to create `SequentialPipelineBlocks` and how blocks connect and work together.
The key insight is that blocks connect through their intermediate inputs and outputs - the "studs and anti-studs" we discussed in the [PipelineBlock guide](pipeline_block.md). When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks.
Let's explore this through an example. We will use the same helper function from the PipelineBlock guide to create blocks.
```py
from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
import torch
def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
class TestBlock(PipelineBlock):
model_name = "test"
@property
def inputs(self):
return inputs
@property
def intermediate_inputs(self):
return intermediate_inputs
@property
def intermediate_outputs(self):
return intermediate_outputs
@property
def description(self):
return description if description is not None else ""
def __call__(self, components, state):
block_state = self.get_block_state(state)
if block_fn is not None:
block_state = block_fn(block_state, state)
self.set_block_state(state, block_state)
return components, state
return TestBlock
```
Let's create a block that produces `batch_size`, which we'll call "input_block":
```py
def input_block_fn(block_state, pipeline_state):
batch_size = len(block_state.prompt)
block_state.batch_size = batch_size * block_state.num_images_per_prompt
return block_state
input_block_cls = make_block(
inputs=[
InputParam(name="prompt", type_hint=list, description="list of text prompts"),
InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt")
],
intermediate_outputs=[
OutputParam(name="batch_size", description="calculated batch size")
],
block_fn=input_block_fn,
description="A block that determines batch_size based on the number of prompts and num_images_per_prompt argument."
)
input_block = input_block_cls()
```
Now let's create a second block that uses the `batch_size` from the first block:
```py
def image_encoder_block_fn(block_state, pipeline_state):
# Simulate processing the image
block_state.image = torch.randn(1, 3, 512, 512)
block_state.batch_size = block_state.batch_size * 2
block_state.image_latents = torch.randn(1, 4, 64, 64)
return block_state
image_encoder_block_cls = make_block(
inputs=[
InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
],
intermediate_inputs=[
InputParam(name="batch_size", type_hint=int)
],
intermediate_outputs=[
OutputParam(name="image_latents", description="latents representing the image")
],
block_fn=image_encoder_block_fn,
description="Encode raw image into its latent presentation"
)
image_encoder_block = image_encoder_block_cls()
```
Now let's connect these blocks to create a `SequentialPipelineBlocks`:
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
# Define a dict mapping block names to block instances
blocks_dict = InsertableDict()
blocks_dict["input"] = input_block
blocks_dict["image_encoder"] = image_encoder_block
# Create the SequentialPipelineBlocks
blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
```
Now you have a `SequentialPipelineBlocks` with 2 blocks:
```py
>>> blocks
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
Description:
Sub-Blocks:
[0] input (TestBlock)
Description: A block that determines batch_size based on the number of prompts and num_images_per_prompt argument.
[1] image_encoder (TestBlock)
Description: Encode raw image into its latent presentation
)
```
When you inspect `blocks.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it.
```py
>>> print(blocks.doc)
class SequentialPipelineBlocks
Inputs:
prompt (`None`, *optional*):
num_images_per_prompt (`None`, *optional*):
image (`PIL.Image`, *optional*):
raw input image to process
Outputs:
batch_size (`None`):
image_latents (`None`):
latents representing the image
```
At runtime, you have data flow like this:
![Data Flow Diagram](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/Editor%20_%20Mermaid%20Chart-2025-06-30-092631.png)
**How SequentialPipelineBlocks Works:**
1. Blocks are executed in the order they're registered in the `blocks_dict`
2. Outputs from one block become available as intermediate inputs to all subsequent blocks
3. The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks
4. Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces
What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs.
+25 -14
View File
@@ -174,39 +174,36 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a
### Regional compilation
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.
For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 810x.
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **810 ×**.
To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.
```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
# Compile only the repeated Transformer layers inside the UNet
pipe.unet.compile_repeated_blocks(fullgraph=True)
# compile only the repeated transformer layers inside the UNet
pipeline.unet.compile_repeated_blocks(fullgraph=True)
```
To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.
```py
class MyUNet(ModelMixin):
_repeated_blocks = ("Transformer2DModel",) # ← compiled by default
```
For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
> [!TIP]
> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This is useful for quick experiments because there aren't as many options for you to set which blocks to compile or adjust compilation flags.
```py
# pip install -U accelerate
@@ -219,8 +216,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.
### Graph breaks
@@ -242,6 +239,12 @@ The `step()` function is [called](https://github.com/huggingface/diffusers/blob/
In general, the `sigmas` should [stay on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240) to avoid the communication sync and latency.
<Tip>
Refer to the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post for maximizing performance with `torch.compile` for diffusion models.
</Tip>
### Benchmarks
Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/benchmarks) dataset to see inference latency and memory usage data for compiled pipelines.
@@ -296,3 +299,11 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
```
## Resources
- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).
These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
- Read the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post
to maximize performance when using `torch.compile`.
@@ -14,6 +14,9 @@ specific language governing permissions and limitations under the License.
Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).
> [!TIP]
> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how they can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups.
For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.
For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
@@ -25,7 +28,7 @@ The table below provides a comparison of optimization strategy combinations and
| quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href="https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d" benchmarking script</a> if you're interested in evaluating your own model.</small>
<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.</small>
This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
@@ -1,23 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Overview
Welcome to 🧨 Diffusers! If you're new to diffusion models and generative AI, and want to learn more, then you've come to the right place. These beginner-friendly tutorials are designed to provide a gentle introduction to diffusion models and help you understand the library fundamentals - the core components and how 🧨 Diffusers is meant to be used.
You'll learn how to use a pipeline for inference to rapidly generate things, and then deconstruct that pipeline to really understand how to use the library as a modular toolbox for building your own diffusion systems. In the next lesson, you'll learn how to train your own diffusion model to generate what you want.
After completing the tutorials, you'll have gained the necessary skills to start exploring the library on your own and see how to use it for your own projects and applications.
Feel free to join our community on [Discord](https://discord.com/invite/JfAtkvEtRb) or the [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) to connect and collaborate with other users and developers!
Let's start diffusing! 🧨
@@ -1,18 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Overview
The inference pipeline supports and enables a wide range of techniques that are divided into two categories:
* Pipeline functionality: these techniques modify the pipeline or extend it for other applications. For example, pipeline callbacks add new features to a pipeline and a pipeline can also be extended for distributed inference.
* Improve inference quality: these techniques increase the visual quality of the generated images. For example, you can enhance your prompts with GPT2 to create better images with lower effort.
@@ -971,6 +971,7 @@ class DreamBoothDataset(Dataset):
def __init__(
self,
args,
instance_data_root,
instance_prompt,
class_prompt,
@@ -980,10 +981,8 @@ class DreamBoothDataset(Dataset):
class_num=None,
size=1024,
repeats=1,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.instance_prompt = instance_prompt
self.custom_instance_prompts = None
@@ -1058,7 +1057,7 @@ class DreamBoothDataset(Dataset):
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
train_resize = transforms.Resize(size, interpolation=interpolation)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_crop = transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
@@ -1075,11 +1074,11 @@ class DreamBoothDataset(Dataset):
# flip
image = train_flip(image)
if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
y1 = max(0, int(round((image.height - self.size) / 2.0)))
x1 = max(0, int(round((image.width - self.size) / 2.0)))
image = train_crop(image)
else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
image = crop(image, y1, x1, h, w)
image = train_transforms(image)
self.pixel_values.append(image)
@@ -1102,7 +1101,7 @@ class DreamBoothDataset(Dataset):
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=interpolation),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
@@ -1827,6 +1826,7 @@ def main(args):
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
args=args,
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
train_text_encoder_ti=args.train_text_encoder_ti,
@@ -1836,7 +1836,6 @@ def main(args):
class_num=args.num_class_images,
size=args.resolution,
repeats=args.repeats,
center_crop=args.center_crop,
)
train_dataloader = torch.utils.data.DataLoader(
+46 -1
View File
@@ -87,6 +87,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -5479,4 +5480,48 @@ edited_image.save("edited_image.png")
### Note
This model is trained on 512x512, so input size is better on 512x512.
For better editing performance, please refer to this powerful model https://huggingface.co/BleachNick/SD3_UltraEdit_freeform and Paper "UltraEdit: Instruction-based Fine-Grained Image
Editing at Scale", many thanks to their contribution!
Editing at Scale", many thanks to their contribution!
# Flux Kontext multiple images
This implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated.
As explained in Section 3 of [the paper](https://arxiv.org/pdf/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
## Example Usage
This pipeline loads two reference images and generates a new image based on them.
```python
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
pipe = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev",
torch_dtype=torch.bfloat16,
custom_pipeline="pipeline_flux_kontext_multiple_images",
)
pipe.to("cuda")
pikachu_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
).convert("RGB")
cat_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
).convert("RGB")
prompts = [
"Pikachu and the cat are sitting together at a pizzeria table, enjoying a delicious pizza.",
]
images = pipe(
multiple_images=[(pikachu_image, cat_image)],
prompt=prompts,
guidance_scale=2.5,
generator=torch.Generator().manual_seed(42),
).images
images[0].save("pizzeria.png")
```
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -1330,7 +1330,7 @@ def main(args):
# controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
controlnet_image = controlnet_image * vae.config.scaling_factor
controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor
control_block_res_samples = controlnet(
hidden_states=noisy_model_input,
@@ -1614,7 +1614,7 @@ def main(args):
)
if args.cond_image_column is not None:
logger.info("I2I fine-tuning enabled.")
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
@@ -58,6 +58,7 @@ from diffusers.training_utils import (
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
offload_models,
)
from diffusers.utils import (
check_min_version,
@@ -1364,43 +1365,34 @@ def main(args):
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not train_dataset.custom_instance_prompts:
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(
instance_prompt_hidden_states_t5,
instance_prompt_hidden_states_llama3,
instance_pooled_prompt_embeds,
_,
_,
_,
) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
(
instance_prompt_hidden_states_t5,
instance_prompt_hidden_states_llama3,
instance_pooled_prompt_embeds,
_,
_,
_,
) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
(class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
)
validation_embeddings = {}
if args.validation_prompt is not None:
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(
validation_embeddings["prompt_embeds_t5"],
validation_embeddings["prompt_embeds_llama3"],
validation_embeddings["pooled_prompt_embeds"],
validation_embeddings["negative_prompt_embeds_t5"],
validation_embeddings["negative_prompt_embeds_llama3"],
validation_embeddings["negative_pooled_prompt_embeds"],
) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
(
validation_embeddings["prompt_embeds_t5"],
validation_embeddings["prompt_embeds_llama3"],
validation_embeddings["pooled_prompt_embeds"],
validation_embeddings["negative_prompt_embeds_t5"],
validation_embeddings["negative_prompt_embeds_llama3"],
validation_embeddings["negative_pooled_prompt_embeds"],
) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1581,12 +1573,10 @@ def main(args):
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
if args.offload:
vae = vae.to(accelerator.device)
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
with offload_models(vae, device=accelerator.device, offload=args.offload):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
if args.offload:
vae = vae.to("cpu")
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
+1 -1
View File
@@ -1,4 +1,4 @@
torch~=2.4.0
torch~=2.7.0
transformers==4.46.1
sentencepiece
aiohttp
+28 -20
View File
@@ -1,10 +1,10 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt
aiohappyeyeballs==2.4.3
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.10.10
aiohttp==3.12.14
# via -r requirements.in
aiosignal==1.3.1
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
@@ -29,7 +29,6 @@ filelock==3.16.1
# huggingface-hub
# torch
# transformers
# triton
frozenlist==1.5.0
# via
# aiohttp
@@ -63,36 +62,42 @@ networkx==3.2.1
# via torch
numpy==2.0.2
# via transformers
nvidia-cublas-cu12==12.1.3.1
nvidia-cublas-cu12==12.6.4.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-cupti-cu12==12.6.80
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.6.77
# via torch
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.6.77
# via torch
nvidia-cudnn-cu12==9.1.0.70
nvidia-cudnn-cu12==9.5.1.17
# via torch
nvidia-cufft-cu12==11.0.2.54
nvidia-cufft-cu12==11.3.0.4
# via torch
nvidia-curand-cu12==10.3.2.106
nvidia-cufile-cu12==1.11.1.6
# via torch
nvidia-cusolver-cu12==11.4.5.107
nvidia-curand-cu12==10.3.7.77
# via torch
nvidia-cusparse-cu12==12.1.0.106
nvidia-cusolver-cu12==11.7.1.2
# via torch
nvidia-cusparse-cu12==12.5.4.2
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.20.5
nvidia-cusparselt-cu12==0.6.3
# via torch
nvidia-nvjitlink-cu12==12.9.86
nvidia-nccl-cu12==2.26.2
# via torch
nvidia-nvjitlink-cu12==12.6.85
# via
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# torch
nvidia-nvtx-cu12==12.6.77
# via torch
packaging==24.1
# via
@@ -105,7 +110,9 @@ prometheus-client==0.21.0
prometheus-fastapi-instrumentator==7.0.0
# via -r requirements.in
propcache==0.2.0
# via yarl
# via
# aiohttp
# yarl
py-consul==1.5.3
# via -r requirements.in
pydantic==2.9.2
@@ -137,7 +144,7 @@ sympy==1.13.3
# via torch
tokenizers==0.20.1
# via transformers
torch==2.4.1
torch==2.7.0
# via -r requirements.in
tqdm==4.66.5
# via
@@ -145,10 +152,11 @@ tqdm==4.66.5
# transformers
transformers==4.46.1
# via -r requirements.in
triton==3.0.0
triton==3.3.0
# via torch
typing-extensions==4.12.2
# via
# aiosignal
# anyio
# exceptiongroup
# fastapi
@@ -163,5 +171,5 @@ urllib3==2.5.0
# via requests
uvicorn==0.32.0
# via -r requirements.in
yarl==1.16.0
yarl==1.18.3
# via aiohttp
+637
View File
@@ -0,0 +1,637 @@
import argparse
import os
import pathlib
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
from diffusers import (
AutoencoderKLWan,
SkyReelsV2DiffusionForcingPipeline,
SkyReelsV2ImageToVideoPipeline,
SkyReelsV2Pipeline,
SkyReelsV2Transformer3DModel,
UniPCMultistepScheduler,
)
TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
"fps_projection.0": "fps_projection.net.0.proj",
"fps_projection.2": "fps_projection.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
# For the I2V model
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# for the FLF2V model
"img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
def load_sharded_safetensors(dir: pathlib.Path):
if "720P" in str(dir):
file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
else:
file_paths = list(dir.glob("model*.safetensors"))
state_dict = {}
for path in file_paths:
state_dict.update(load_file(path))
return state_dict
def get_transformer_config(model_type: str) -> Dict[str, Any]:
if model_type == "SkyReels-V2-DF-1.3B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-DF-1.3B-540P",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 12,
"inject_sample_info": True,
"num_layers": 30,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "SkyReels-V2-DF-14B-720P":
config = {
"model_id": "Skywork/SkyReels-V2-DF-14B-720P",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "SkyReels-V2-DF-14B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-DF-14B-540P",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "SkyReels-V2-T2V-14B-720P":
config = {
"model_id": "Skywork/SkyReels-V2-T2V-14B-720P",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "SkyReels-V2-T2V-14B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-T2V-14B-540P",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "SkyReels-V2-I2V-1.3B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-1.3B-540P",
"diffusers_config": {
"added_kv_proj_dim": 1536,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 12,
"inject_sample_info": False,
"num_layers": 30,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
},
}
elif model_type == "SkyReels-V2-I2V-14B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-14B-540P",
"diffusers_config": {
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
},
}
elif model_type == "SkyReels-V2-I2V-14B-720P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-14B-720P",
"diffusers_config": {
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
},
}
elif model_type == "SkyReels-V2-FLF2V-1.3B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-1.3B-540P",
"diffusers_config": {
"added_kv_proj_dim": 1536,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 12,
"inject_sample_info": False,
"num_layers": 30,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
"pos_embed_seq_len": 514,
},
}
elif model_type == "SkyReels-V2-FLF2V-14B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-14B-540P",
"diffusers_config": {
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
"pos_embed_seq_len": 514,
},
}
elif model_type == "SkyReels-V2-FLF2V-14B-720P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-14B-720P",
"diffusers_config": {
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
"pos_embed_seq_len": 514,
},
}
return config
def convert_transformer(model_type: str):
config = get_transformer_config(model_type)
diffusers_config = config["diffusers_config"]
model_id = config["model_id"]
if "1.3B" in model_type:
original_state_dict = load_file(hf_hub_download(model_id, "model.safetensors"))
else:
os.makedirs(model_type, exist_ok=True)
model_dir = pathlib.Path(model_type)
if "720P" in model_type:
top_shard = 7 if "I2V" in model_type else 6
zeros = "0" * (4 if "I2V" or "T2V" in model_type else 3)
model_name = "diffusion_pytorch_model"
elif "540P" in model_type:
top_shard = 14 if "I2V" in model_type else 12
model_name = "model"
for i in range(1, top_shard + 1):
shard_path = f"{model_name}-{i:05d}-of-{zeros}{top_shard}.safetensors"
hf_hub_download(model_id, shard_path, local_dir=model_dir)
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
transformer = SkyReelsV2Transformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
if "FLF2V" in model_type:
if (
hasattr(transformer.condition_embedder, "image_embedder")
and hasattr(transformer.condition_embedder.image_embedder, "pos_embed")
and transformer.condition_embedder.image_embedder.pos_embed is not None
):
pos_embed_shape = transformer.condition_embedder.image_embedder.pos_embed.shape
original_state_dict["condition_embedder.image_embedder.pos_embed"] = torch.zeros(pos_embed_shape)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def convert_vae():
vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
new_state_dict = {}
# Create mappings for specific components
middle_key_mapping = {
# Encoder middle block
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
"encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
"encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
"encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
"encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
"encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
"encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
"encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
"encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
"encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
"encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
# Decoder middle block
"decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
"decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
"decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
"decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
"decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
"decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
"decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
"decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
"decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
"decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
}
# Create a mapping for attention blocks
attention_mapping = {
# Encoder middle attention
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
"encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
"encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
"encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
"encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
# Decoder middle attention
"decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
"decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
"decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
"decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
}
# Create a mapping for the head components
head_mapping = {
# Encoder head
"encoder.head.0.gamma": "encoder.norm_out.gamma",
"encoder.head.2.bias": "encoder.conv_out.bias",
"encoder.head.2.weight": "encoder.conv_out.weight",
# Decoder head
"decoder.head.0.gamma": "decoder.norm_out.gamma",
"decoder.head.2.bias": "decoder.conv_out.bias",
"decoder.head.2.weight": "decoder.conv_out.weight",
}
# Create a mapping for the quant components
quant_mapping = {
"conv1.weight": "quant_conv.weight",
"conv1.bias": "quant_conv.bias",
"conv2.weight": "post_quant_conv.weight",
"conv2.bias": "post_quant_conv.bias",
}
# Process each key in the state dict
for key, value in old_state_dict.items():
# Handle middle block keys using the mapping
if key in middle_key_mapping:
new_key = middle_key_mapping[key]
new_state_dict[new_key] = value
# Handle attention blocks using the mapping
elif key in attention_mapping:
new_key = attention_mapping[key]
new_state_dict[new_key] = value
# Handle head keys using the mapping
elif key in head_mapping:
new_key = head_mapping[key]
new_state_dict[new_key] = value
# Handle quant keys using the mapping
elif key in quant_mapping:
new_key = quant_mapping[key]
new_state_dict[new_key] = value
# Handle encoder conv1
elif key == "encoder.conv1.weight":
new_state_dict["encoder.conv_in.weight"] = value
elif key == "encoder.conv1.bias":
new_state_dict["encoder.conv_in.bias"] = value
# Handle decoder conv1
elif key == "decoder.conv1.weight":
new_state_dict["decoder.conv_in.weight"] = value
elif key == "decoder.conv1.bias":
new_state_dict["decoder.conv_in.bias"] = value
# Handle encoder downsamples
elif key.startswith("encoder.downsamples."):
# Convert to down_blocks
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
# Convert residual block naming but keep the original structure
if ".residual.0.gamma" in new_key:
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
elif ".residual.2.bias" in new_key:
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
elif ".residual.2.weight" in new_key:
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
elif ".residual.3.gamma" in new_key:
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
elif ".residual.6.bias" in new_key:
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
elif ".residual.6.weight" in new_key:
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
elif ".shortcut.bias" in new_key:
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
elif ".shortcut.weight" in new_key:
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
new_state_dict[new_key] = value
# Handle decoder upsamples
elif key.startswith("decoder.upsamples."):
# Convert to up_blocks
parts = key.split(".")
block_idx = int(parts[2])
# Group residual blocks
if "residual" in key:
if block_idx in [0, 1, 2]:
new_block_idx = 0
resnet_idx = block_idx
elif block_idx in [4, 5, 6]:
new_block_idx = 1
resnet_idx = block_idx - 4
elif block_idx in [8, 9, 10]:
new_block_idx = 2
resnet_idx = block_idx - 8
elif block_idx in [12, 13, 14]:
new_block_idx = 3
resnet_idx = block_idx - 12
else:
# Keep as is for other blocks
new_state_dict[key] = value
continue
# Convert residual block naming
if ".residual.0.gamma" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
elif ".residual.2.bias" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
elif ".residual.2.weight" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
elif ".residual.3.gamma" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
elif ".residual.6.bias" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
elif ".residual.6.weight" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
else:
new_key = key
new_state_dict[new_key] = value
# Handle shortcut connections
elif ".shortcut." in key:
if block_idx == 4:
new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
new_state_dict[new_key] = value
# Handle upsamplers
elif ".resample." in key or ".time_conv." in key:
if block_idx == 3:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
elif block_idx == 7:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
elif block_idx == 11:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_state_dict[new_key] = value
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_state_dict[new_key] = value
else:
# Keep other keys unchanged
new_state_dict[key] = value
with init_empty_weights():
vae = AutoencoderKLWan()
vae.load_state_dict(new_state_dict, strict=True, assign=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--dtype", default="fp32")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
if __name__ == "__main__":
args = get_args()
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
transformer = convert_transformer(args.model_type).to(dtype=dtype)
vae = convert_vae()
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction",
num_train_timesteps=1000,
use_flow_sigmas=True,
)
if "I2V" in args.model_type or "FLF2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
pipe = SkyReelsV2ImageToVideoPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
image_encoder=image_encoder,
image_processor=image_processor,
)
elif "T2V" in args.model_type:
pipe = SkyReelsV2Pipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
)
elif "DF" in args.model_type:
pipe = SkyReelsV2DiffusionForcingPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
)
pipe.save_pretrained(
args.output_path,
safe_serialization=True,
max_shard_size="5GB",
# push_to_hub=True,
# repo_id=f"<place_holder>/{args.model_type}-Diffusers",
)
+1 -1
View File
@@ -110,7 +110,7 @@ _deps = [
"jax>=0.4.1",
"jaxlib>=0.4.1",
"Jinja2",
"k-diffusion>=0.0.12",
"k-diffusion==0.0.12",
"torchsde",
"note_seq",
"librosa",
+76
View File
@@ -34,10 +34,13 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"modular_pipelines": [],
"pipelines": [],
"quantizers.pipe_quant_config": ["PipelineQuantizationConfig"],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
@@ -130,14 +133,29 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"AutoGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"PerturbedAttentionGuidance",
"SkipLayerGuidance",
"SmoothedEnergyGuidance",
"TangentialClassifierFreeGuidance",
]
)
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
)
@@ -145,6 +163,7 @@ else:
[
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AttentionBackendName",
"AuraFlowTransformer2DModel",
"AutoencoderDC",
"AutoencoderKL",
@@ -201,6 +220,7 @@ else:
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
"SkyReelsV2Transformer3DModel",
"SparseControlNetModel",
"StableAudioDiTModel",
"StableCascadeUNet",
@@ -219,6 +239,15 @@ else:
"VQModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"attention_backend",
]
)
_import_structure["modular_pipelines"].extend(
[
"ComponentsManager",
"ComponentSpec",
"ModularPipeline",
"ModularPipelineBlocks",
]
)
_import_structure["optimization"] = [
@@ -333,6 +362,14 @@ except OptionalDependencyNotAvailable:
]
else:
_import_structure["modular_pipelines"].extend(
[
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"WanAutoBlocks",
"WanModularPipeline",
]
)
_import_structure["pipelines"].extend(
[
"AllegroPipeline",
@@ -456,6 +493,11 @@ else:
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
"SkyReelsV2DiffusionForcingPipeline",
"SkyReelsV2DiffusionForcingVideoToVideoPipeline",
"SkyReelsV2ImageToVideoPipeline",
"SkyReelsV2Pipeline",
"StableAudioPipeline",
"StableAudioProjectionModel",
"StableCascadeCombinedPipeline",
@@ -545,6 +587,7 @@ else:
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
@@ -751,18 +794,32 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .guiders import (
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
)
from .hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
)
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
AttentionBackendName,
AuraFlowTransformer2DModel,
AutoencoderDC,
AutoencoderKL,
@@ -819,6 +876,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
SkyReelsV2Transformer3DModel,
SparseControlNetModel,
StableAudioDiTModel,
T2IAdapter,
@@ -836,6 +894,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
attention_backend,
)
from .modular_pipelines import (
ComponentsManager,
ComponentSpec,
ModularPipeline,
ModularPipelineBlocks,
)
from .optimization import (
get_constant_schedule,
@@ -933,6 +998,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipelines import (
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
WanAutoBlocks,
WanModularPipeline,
)
from .pipelines import (
AllegroPipeline,
AltDiffusionImg2ImgPipeline,
@@ -1053,6 +1124,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
SkyReelsV2DiffusionForcingImageToVideoPipeline,
SkyReelsV2DiffusionForcingPipeline,
SkyReelsV2DiffusionForcingVideoToVideoPipeline,
SkyReelsV2ImageToVideoPipeline,
SkyReelsV2Pipeline,
StableAudioPipeline,
StableAudioProjectionModel,
StableCascadeCombinedPipeline,
+35
View File
@@ -207,3 +207,38 @@ class IPAdapterScaleCutoffCallback(PipelineCallback):
if step_index == cutoff_step:
pipeline.set_ip_adapter_scale(0.0)
return callback_kwargs
class SD3CFGCutoffCallback(PipelineCallback):
"""
Callback function for Stable Diffusion 3 Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
`cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""
tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)
if step_index == cutoff_step:
prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
pooled_prompt_embeds = callback_kwargs[self.tensor_inputs[1]]
pooled_prompt_embeds = pooled_prompt_embeds[
-1:
] # "-1" denotes the embeddings for conditional pooled text tokens.
pipeline._guidance_scale = 0.0
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = pooled_prompt_embeds
return callback_kwargs
+134
View File
@@ -0,0 +1,134 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage example:
TODO
"""
import ast
import importlib.util
import os
from argparse import ArgumentParser, Namespace
from pathlib import Path
from ..utils import logging
from . import BaseDiffusersCLICommand
EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
CONFIG = "config.json"
def conversion_command_factory(args: Namespace):
return CustomBlocksCommand(args.block_module_name, args.block_class_name)
class CustomBlocksCommand(BaseDiffusersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
conversion_parser = parser.add_parser("custom_blocks")
conversion_parser.add_argument(
"--block_module_name",
type=str,
default="block.py",
help="Module filename in which the custom block will be implemented.",
)
conversion_parser.add_argument(
"--block_class_name",
type=str,
default=None,
help="Name of the custom block. If provided None, we will try to infer it.",
)
conversion_parser.set_defaults(func=conversion_command_factory)
def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
self.logger = logging.get_logger("diffusers-cli/custom_blocks")
self.block_module_name = Path(block_module_name)
self.block_class_name = block_class_name
def run(self):
# determine the block to be saved.
out = self._get_class_names(self.block_module_name)
classes_found = list({cls for cls, _ in out})
if self.block_class_name is not None:
child_class, parent_class = self._choose_block(out, self.block_class_name)
if child_class is None and parent_class is None:
raise ValueError(
"`block_class_name` could not be retrieved. Available classes from "
f"{self.block_module_name}:\n{classes_found}"
)
else:
self.logger.info(
f"Found classes: {classes_found} will be using {classes_found[0]}. "
"If this needs to be changed, re-run the command specifying `block_class_name`."
)
child_class, parent_class = out[0][0], out[0][1]
# dynamically get the custom block and initialize it to call `save_pretrained` in the current directory.
# the user is responsible for running it, so I guess that is safe?
module_name = f"__dynamic__{self.block_module_name.stem}"
spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
getattr(module, child_class)().save_pretrained(os.getcwd())
# or, we could create it manually.
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")
def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:
if cls == chosen:
return cls, base
return None, None
def _get_class_names(self, file_path):
source = file_path.read_text(encoding="utf-8")
try:
tree = ast.parse(source, filename=file_path)
except SyntaxError as e:
raise ValueError(f"Could not parse {file_path!r}: {e}") from e
results: list[tuple[str, str]] = []
for node in tree.body:
if not isinstance(node, ast.ClassDef):
continue
# extract all base names for this class
base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]
# for each allowed base that appears in the class's bases, emit a tuple
for allowed in EXPECTED_PARENT_CLASSES:
if allowed in base_names:
results.append((node.name, allowed))
return results
def _get_base_name(self, node: ast.expr):
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
val = self._get_base_name(node.value)
return f"{val}.{node.attr}" if val else node.attr
return None
def _create_automap(self, parent_class, child_class):
module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
auto_map = {f"{parent_class}": f"{module}.{child_class}"}
return {"auto_map": auto_map}
+2
View File
@@ -15,6 +15,7 @@
from argparse import ArgumentParser
from .custom_blocks import CustomBlocksCommand
from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand
@@ -26,6 +27,7 @@ def main():
# Register commands
EnvironmentCommand.register_subcommand(commands_parser)
FP16SafetensorsCommand.register_subcommand(commands_parser)
CustomBlocksCommand.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()
+10 -1
View File
@@ -176,6 +176,7 @@ class ConfigMixin:
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
subfolder = kwargs.pop("subfolder", None)
self._upload_folder(
save_directory,
@@ -183,6 +184,7 @@ class ConfigMixin:
token=token,
commit_message=commit_message,
create_pr=create_pr,
subfolder=subfolder,
)
@classmethod
@@ -601,6 +603,10 @@ class ConfigMixin:
value = value.tolist()
elif isinstance(value, Path):
value = value.as_posix()
elif hasattr(value, "to_dict") and callable(value.to_dict):
value = value.to_dict()
elif isinstance(value, list):
value = [to_json_saveable(v) for v in value]
return value
if "quantization_config" in config_dict:
@@ -757,4 +763,7 @@ class LegacyConfigMixin(ConfigMixin):
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
if remapped_class is cls:
return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
else:
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
+1 -1
View File
@@ -17,7 +17,7 @@ deps = {
"jax": "jax>=0.4.1",
"jaxlib": "jaxlib>=0.4.1",
"Jinja2": "Jinja2",
"k-diffusion": "k-diffusion>=0.0.12",
"k-diffusion": "k-diffusion==0.0.12",
"torchsde": "torchsde",
"note_seq": "note_seq",
"librosa": "librosa",
+39
View File
@@ -0,0 +1,39 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
GuiderType = Union[
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
]
@@ -0,0 +1,188 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AdaptiveProjectedGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_apg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_apg_enabled():
num_conditions += 1
return num_conditions
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
+190
View File
@@ -0,0 +1,190 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AutoGuidance(BaseGuidance):
"""
AutoGuidance: https://huggingface.co/papers/2406.02507
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
auto_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided.
auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
dropout (`float`, *optional*):
The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
`auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
dropout: Optional[float] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.auto_guidance_layers = auto_guidance_layers
self.auto_guidance_config = auto_guidance_config
self.dropout = dropout
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if auto_guidance_layers is None and auto_guidance_config is None:
raise ValueError(
"Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance."
)
if auto_guidance_layers is not None and auto_guidance_config is not None:
raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
if (dropout is None and auto_guidance_layers is not None) or (
dropout is not None and auto_guidance_layers is None
):
raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")
if auto_guidance_layers is not None:
if isinstance(auto_guidance_layers, int):
auto_guidance_layers = [auto_guidance_layers]
if not isinstance(auto_guidance_layers, list):
raise ValueError(
f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
)
auto_guidance_config = [
LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
]
if isinstance(auto_guidance_config, dict):
auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)
if isinstance(auto_guidance_config, LayerSkipConfig):
auto_guidance_config = [auto_guidance_config]
if not isinstance(auto_guidance_config, list):
raise ValueError(
f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
)
elif isinstance(next(iter(auto_guidance_config), None), dict):
auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]
self.auto_guidance_config = auto_guidance_config
self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_ag_enabled() and self.is_unconditional:
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_ag_enabled() and self.is_unconditional:
for name in self._auto_guidance_hook_names:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_ag_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_ag_enabled():
num_conditions += 1
return num_conditions
def _is_ag_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
@@ -0,0 +1,141 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class ClassifierFreeGuidance(BaseGuidance):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
proposes scaling and shifting the conditional distribution based on the difference between conditional and
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
@@ -0,0 +1,152 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class ClassifierFreeZeroStarGuidance(BaseGuidance):
"""
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
quality of generated images.
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
zero_init_steps (`int`, defaults to `1`):
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
zero_init_steps: int = 1,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled():
pred = pred_cond
else:
pred_cond_flat = pred_cond.flatten(1)
pred_uncond_flat = pred_uncond.flatten(1)
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
pred_uncond = pred_uncond * alpha
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
cond_dtype = cond.dtype
cond = cond.float()
uncond = uncond.float()
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
scale = dot_product / squared_norm
return scale.to(dtype=cond_dtype)
+309
View File
@@ -0,0 +1,309 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self
from ..configuration_utils import ConfigMixin
from ..utils import PushToHubMixin, get_logger
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
GUIDER_CONFIG_NAME = "guider_config.json"
logger = get_logger(__name__) # pylint: disable=invalid-name
class BaseGuidance(ConfigMixin, PushToHubMixin):
r"""Base class providing the skeleton for implementing guidance techniques."""
config_name = GUIDER_CONFIG_NAME
_input_predictions = None
_identifier_key = "__guidance_identifier__"
def __init__(self, start: float = 0.0, stop: float = 1.0):
self._start = start
self._stop = stop
self._step: int = None
self._num_inference_steps: int = None
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
self._enabled = True
if not (0.0 <= start < 1.0):
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
if not (start <= stop <= 1.0):
raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")
if self._input_predictions is None or not isinstance(self._input_predictions, list):
raise ValueError(
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def disable(self):
self._enabled = False
def enable(self):
self._enabled = True
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
self._timestep = timestep
self._count_prepared = 0
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
"""
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
the values of the provided keyword arguments to this method.
Args:
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with a
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
conditional data identifier and the second element must be the unconditional data identifier or None.
Example:
```
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
BaseGuidance.set_input_fields(
latents="latents",
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)
```
"""
for key, value in kwargs.items():
is_string = isinstance(value, str)
is_tuple_of_str_with_len_2 = (
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
)
if not (is_string or is_tuple_of_str_with_len_2):
raise ValueError(
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
subclasses to implement specific model preparation logic.
"""
self._count_prepared += 1
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
"""
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
modifications made during `prepare_models`.
"""
pass
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
def __call__(self, data: List["BlockState"]) -> Any:
if not all(hasattr(d, "noise_pred") for d in data):
raise ValueError("Expected all data to have `noise_pred` attribute.")
if len(data) != self.num_conditions:
raise ValueError(
f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
)
forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
return self.forward(**forward_inputs)
def forward(self, *args, **kwargs) -> Any:
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
@property
def is_conditional(self) -> bool:
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
@property
def is_unconditional(self) -> bool:
return not self.is_conditional
@property
def num_conditions(self) -> int:
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
@classmethod
def _prepare_batch(
cls,
input_fields: Dict[str, Union[str, Tuple[str, str]]],
data: "BlockState",
tuple_index: int,
identifier: str,
) -> "BlockState":
"""
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
Args:
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation. If a string is provided, it will be used as the
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
length 2 is provided, the first element must be the conditional data identifier and the second element
must be the unconditional data identifier or None.
data (`BlockState`):
The input data to be prepared.
tuple_index (`int`):
The index to use when accessing input fields that are tuples.
Returns:
`BlockState`: The prepared batch of data.
"""
from ..modular_pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError(
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
)
data_batch = {}
for key, value in input_fields.items():
try:
if isinstance(value, str):
data_batch[key] = getattr(data, value)
elif isinstance(value, tuple):
data_batch[key] = getattr(data, value[tuple_index])
else:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
) -> Self:
r"""
Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
saved with [`~BaseGuidance.save_pretrained`].
subfolder (`str`, *optional*):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
<Tip>
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
`huggingface-cli login`. You can also activate the special
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
firewalled environment.
</Tip>
"""
config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
return_unused_kwargs=True,
return_commit_hash=True,
**kwargs,
)
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a guider configuration object to a directory so that it can be reloaded using the
[`~BaseGuidance.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
@@ -0,0 +1,271 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from ..utils import get_logger
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
logger = get_logger(__name__) # pylint: disable=invalid-name
class PerturbedAttentionGuidance(BaseGuidance):
"""
Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
layers.
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
and implementation details.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
perturbed_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for perturbed attention guidance.
perturbed_guidance_start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which perturbed attention guidance starts.
perturbed_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
perturbed_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
If not provided, `perturbed_guidance_config` must be provided.
perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
# NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
# the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
# complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
# for each model architecture.
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
perturbed_guidance_scale: float = 2.8,
perturbed_guidance_start: float = 0.01,
perturbed_guidance_stop: float = 0.2,
perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = perturbed_guidance_scale
self.skip_layer_guidance_start = perturbed_guidance_start
self.skip_layer_guidance_stop = perturbed_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if perturbed_guidance_config is None:
if perturbed_guidance_layers is None:
raise ValueError(
"`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
)
perturbed_guidance_config = LayerSkipConfig(
indices=perturbed_guidance_layers,
fqn="auto",
skip_attention=False,
skip_attention_scores=True,
skip_ff=False,
)
else:
if perturbed_guidance_layers is not None:
raise ValueError(
"`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
)
if isinstance(perturbed_guidance_config, dict):
perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
if isinstance(perturbed_guidance_config, LayerSkipConfig):
perturbed_guidance_config = [perturbed_guidance_config]
if not isinstance(perturbed_guidance_config, list):
raise ValueError(
"`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
)
elif isinstance(next(iter(perturbed_guidance_config), None), dict):
perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
for config in perturbed_guidance_config:
if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
logger.warning(
"Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
"Please check your configuration. Modifying the config to match the expected values."
)
config.skip_attention = False
config.skip_attention_scores = True
config.skip_ff = False
self.skip_layer_config = perturbed_guidance_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero
@@ -0,0 +1,262 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class SkipLayerGuidance(BaseGuidance):
"""
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
batch of data, apart from the conditional and unconditional batches already used in CFG
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
based on the difference between conditional without skipping and conditional with skipping predictions.
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
version of the model for the conditional prediction).
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
generation quality in video diffusion models.
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
skip_layer_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
skip_layer_guidance_start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which skip layer guidance starts.
skip_layer_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which skip layer guidance stops.
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_start: float = 0.01,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
self.skip_layer_guidance_start = skip_layer_guidance_start
self.skip_layer_guidance_stop = skip_layer_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= skip_layer_guidance_start < 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
)
if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
)
if skip_layer_guidance_layers is None and skip_layer_config is None:
raise ValueError(
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
)
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
if skip_layer_guidance_layers is not None:
if isinstance(skip_layer_guidance_layers, int):
skip_layer_guidance_layers = [skip_layer_guidance_layers]
if not isinstance(skip_layer_guidance_layers, list):
raise ValueError(
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
)
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
if isinstance(skip_layer_config, dict):
skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)
if isinstance(skip_layer_config, LayerSkipConfig):
skip_layer_config = [skip_layer_config]
if not isinstance(skip_layer_config, list):
raise ValueError(
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
)
elif isinstance(next(iter(skip_layer_config), None), dict):
skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]
self.skip_layer_config = skip_layer_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero
@@ -0,0 +1,251 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class SmoothedEnergyGuidance(BaseGuidance):
"""
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
future without warning or guarantee of reproducibility. This implementation assumes:
- Generated images are square (height == width)
- The model does not combine different modalities together (e.g., text and image latent streams are not combined
together such as Flux)
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
seg_guidance_scale (`float`, defaults to `3.0`):
The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
seg_blur_sigma (`float`, defaults to `9999999.0`):
The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
seg_blur_threshold_inf (`float`, defaults to `9999.0`):
The threshold above which the blur is considered infinite.
seg_guidance_start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
seg_guidance_stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
seg_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
Diffusion 3.5 Medium.
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
seg_guidance_scale: float = 2.8,
seg_blur_sigma: float = 9999999.0,
seg_blur_threshold_inf: float = 9999.0,
seg_guidance_start: float = 0.0,
seg_guidance_stop: float = 1.0,
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale
self.seg_blur_sigma = seg_blur_sigma
self.seg_blur_threshold_inf = seg_blur_threshold_inf
self.seg_guidance_start = seg_guidance_start
self.seg_guidance_stop = seg_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= seg_guidance_start < 1.0):
raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")
if seg_guidance_layers is None and seg_guidance_config is None:
raise ValueError(
"Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
)
if seg_guidance_layers is not None and seg_guidance_config is not None:
raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
if seg_guidance_layers is not None:
if isinstance(seg_guidance_layers, int):
seg_guidance_layers = [seg_guidance_layers]
if not isinstance(seg_guidance_layers, list):
raise ValueError(
f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
)
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
if isinstance(seg_guidance_config, dict):
seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
seg_guidance_config = [seg_guidance_config]
if not isinstance(seg_guidance_config, list):
raise ValueError(
f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
)
elif isinstance(next(iter(seg_guidance_config), None), dict):
seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]
self.seg_guidance_config = seg_guidance_config
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
def cleanup_models(self, denoiser: torch.nn.Module):
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_seg: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_seg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_seg
pred = pred_cond if self.use_original_formulation else pred_cond_seg
pred = pred + self.seg_guidance_scale * shift
elif not self._is_seg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_seg = pred_cond - pred_cond_seg
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_seg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_seg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
return is_within_range and not is_zero
@@ -0,0 +1,143 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class TangentialClassifierFreeGuidance(BaseGuidance):
"""
Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_tcfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._num_outputs_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_tcfg_enabled():
num_conditions += 1
return num_conditions
def _is_tcfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def normalized_guidance(
pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
) -> torch.Tensor:
cond_dtype = pred_cond.dtype
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
preds = preds.flatten(2)
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
Vh_modified = Vh.clone()
Vh_modified[:, 1] = 0
uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
pred = pred_cond if use_original_formulation else pred_uncond
shift = pred_cond - pred_uncond
pred = pred + guidance_scale * shift
return pred
+2
View File
@@ -20,5 +20,7 @@ if is_torch_available():
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
+14 -1
View File
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
@@ -28,3 +34,10 @@ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
for submodule_name, submodule in module.named_modules():
if submodule_name == fqn:
return submodule
return None
+10
View File
@@ -107,6 +107,7 @@ class TransformerBlockRegistry:
def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
# AttnProcessor2_0
AttentionProcessorRegistry.register(
@@ -124,6 +125,14 @@ def _register_attention_processors_metadata():
),
)
# WanAttnProcessor2_0
AttentionProcessorRegistry.register(
model_class=WanAttnProcessor2_0,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -261,4 +270,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# fmt: on
+2 -1
View File
@@ -18,6 +18,7 @@ from typing import Any, Callable, List, Optional, Tuple
import torch
from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
@@ -567,7 +568,7 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
_apply_faster_cache_on_denoiser(module, config)
for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
continue
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
_apply_faster_cache_on_attention_class(name, submodule, config)
+19 -2
View File
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
@@ -37,7 +38,7 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
_GROUP_ID_LAZY_LEAF = "lazy_leafs"
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -82,6 +83,7 @@ class ModuleGroup:
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
group_id: Optional[int] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -100,7 +102,10 @@ class ModuleGroup:
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
self.group_id = group_id if group_id is not None else str(id(self))
short_hash = _compute_group_hash(self.group_id)
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
all_tensors = []
for module in self.modules:
@@ -609,6 +614,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
@@ -621,6 +627,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
@@ -655,6 +662,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
@@ -686,6 +694,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(submodule, group, None, config=config)
modules_with_group_offloading.add(name)
@@ -732,6 +741,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(parent_module, group, None, config=config)
@@ -753,6 +763,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=False,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=_GROUP_ID_LAZY_LEAF,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
@@ -873,6 +884,12 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
raise ValueError("Group offloading is not enabled for the provided module.")
def _compute_group_hash(group_id):
hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
# first 16 characters for a reasonably short but unique name
return hashed_id[:16]
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
r"""
Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
+263
View File
@@ -0,0 +1,263 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import (
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
_ATTENTION_CLASSES,
_FEEDFORWARD_CLASSES,
_get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_LAYER_SKIP_HOOK = "layer_skip_hook"
# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
# either remove or make it serializable
@dataclass
class LayerSkipConfig:
r"""
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
provide the correct fqn.
skip_attention (`bool`, defaults to `True`):
Whether to skip attention blocks.
skip_ff (`bool`, defaults to `True`):
Whether to skip feed-forward blocks.
skip_attention_scores (`bool`, defaults to `False`):
Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
projections as the output of scaled dot product attention.
dropout (`float`, defaults to `1.0`):
The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
skipped layers are fully retained, which is equivalent to not skipping any layers.
"""
indices: List[int]
fqn: str = "auto"
skip_attention: bool = True
skip_attention_scores: bool = False
skip_ff: bool = True
dropout: float = 1.0
def __post_init__(self):
if not (0 <= self.dropout <= 1):
raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
def to_dict(self):
return asdict(self)
@staticmethod
def from_dict(data: dict) -> "LayerSkipConfig":
return LayerSkipConfig(**data)
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.nn.functional.scaled_dot_product_attention:
query = kwargs.get("query", None)
key = kwargs.get("key", None)
value = kwargs.get("value", None)
query = query if query is not None else args[0]
key = key if key is not None else args[1]
value = value if value is not None else args[2]
# If the Q sequence length does not match KV sequence length, methods like
# Perturbed Attention Guidance cannot be used (because the caller expects
# the same sequence length as Q, but if we return V here, it will not match).
# When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
# the overall effect would that be of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
if query.shape[2] == value.shape[2]:
return value
return func(*args, **kwargs)
class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
self.skip_processor_output_fn = skip_processor_output_fn
self.skip_attention_scores = skip_attention_scores
self.dropout = dropout
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.skip_attention_scores:
if not math.isclose(self.dropout, 1.0):
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
else:
if math.isclose(self.dropout, 1.0):
output = self.skip_processor_output_fn(module, *args, **kwargs)
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
class FeedForwardSkipHook(ModelHook):
def __init__(self, dropout: float):
super().__init__()
self.dropout = dropout
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if math.isclose(self.dropout, 1.0):
output = kwargs.get("hidden_states", None)
if output is None:
output = kwargs.get("x", None)
if output is None and len(args) > 0:
output = args[0]
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
class TransformerBlockSkipHook(ModelHook):
def __init__(self, dropout: float):
super().__init__()
self.dropout = dropout
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if math.isclose(self.dropout, 1.0):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
if self._metadata.return_encoder_hidden_states_index is None:
output = original_hidden_states
else:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
output = (original_hidden_states, original_encoder_hidden_states)
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
r"""
Apply layer skipping to internal layers of a transformer.
Args:
module (`torch.nn.Module`):
The transformer model to which the layer skip hook should be applied.
config (`LayerSkipConfig`):
The configuration for the layer skip hook.
Example:
```python
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
>>> apply_layer_skip_hook(transformer, config)
```
"""
_apply_layer_skip_hook(module, config)
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
name = name or _LAYER_SKIP_HOOK
if config.skip_attention and config.skip_attention_scores:
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
raise ValueError(
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
)
if len(config.indices) == 0:
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
if config.skip_attention and config.skip_ff:
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = TransformerBlockSkipHook(config.dropout)
registry.register_hook(hook, name)
elif config.skip_attention or config.skip_attention_scores:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
registry.register_hook(hook, name)
if config.skip_ff:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = FeedForwardSkipHook(config.dropout)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
@@ -18,6 +18,7 @@ from typing import Any, Callable, Optional, Tuple, Union
import torch
from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
config.spatial_attention_block_skip_range = 2
for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
@@ -0,0 +1,167 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import asdict, dataclass
from typing import List, Optional
import torch
import torch.nn.functional as F
from ..utils import get_logger
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
@dataclass
class SmoothedEnergyGuidanceConfig:
r"""
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
provide the correct fqn.
_query_proj_identifiers (`List[str]`, defaults to `None`):
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
`None`, `to_q` is used by default.
"""
indices: List[int]
fqn: str = "auto"
_query_proj_identifiers: List[str] = None
def to_dict(self):
return asdict(self)
@staticmethod
def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig":
return SmoothedEnergyGuidanceConfig(**data)
class SmoothedEnergyGuidanceHook(ModelHook):
def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
super().__init__()
self.blur_sigma = blur_sigma
self.blur_threshold_inf = blur_threshold_inf
def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
# Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
return smoothed_output
def _apply_smoothed_energy_guidance_hook(
module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None
) -> None:
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
if config._query_proj_identifiers is None:
config._query_proj_identifiers = ["to_q"]
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
for submodule_name, submodule in block.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
continue
for identifier in config._query_proj_identifiers:
query_proj = getattr(submodule, identifier, None)
if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
continue
logger.debug(
f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
)
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
hook = SmoothedEnergyGuidanceHook(blur_sigma)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
"""
This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian blur.
However, some models use joint text-visual token attention for which this may not be suitable. Additionally, this
implementation also assumes that the visual tokens come from a square image/video. In practice, despite these
assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for
Smoothed Energy Guidance.
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
future without warning or guarantee of reproducibility.
"""
assert query.ndim == 3
is_inf = sigma > sigma_threshold_inf
batch_size, seq_len, embed_dim = query.shape
seq_len_sqrt = int(math.sqrt(seq_len))
num_square_tokens = seq_len_sqrt * seq_len_sqrt
query_slice = query[:, :num_square_tokens, :]
query_slice = query_slice.permute(0, 2, 1)
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
if is_inf:
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
kernel_size_half = (kernel_size - 1) / 2
x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
kernel1d = kernel1d.to(query)
kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
query_slice = F.pad(query_slice, padding, mode="reflect")
query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
else:
query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
query_slice = query_slice.permute(0, 2, 1)
query[:, :num_square_tokens, :] = query_slice.clone()
return query
+4
View File
@@ -78,12 +78,14 @@ if is_torch_available():
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
"ModularIPAdapterMixin",
]
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -101,6 +103,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
ModularIPAdapterMixin,
SD3IPAdapterMixin,
)
from .lora_pipeline import (
@@ -117,6 +120,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Mochi1LoraLoaderMixin,
SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
SkyReelsV2LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
+255 -4
View File
@@ -40,8 +40,6 @@ if is_transformers_available():
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
FluxAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
@@ -354,6 +352,256 @@ class IPAdapterMixin:
self.unet.set_attn_processor(attn_procs)
class ModularIPAdapterMixin:
"""Mixin for handling IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
state_dicts.append(state_dict)
unet_name = getattr(self, "unet_name", "unet")
unet = getattr(self, unet_name)
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
extra_loras = unet._load_ip_adapter_loras(state_dicts)
if extra_loras != {}:
if not USE_PEFT_BACKEND:
logger.warning("PEFT backend is required to load these weights.")
else:
# apply the IP Adapter Face ID LoRA weights
peft_config = getattr(unet, "peft_config", {})
for k, lora in extra_loras.items():
if f"faceid_{k}" not in peft_config:
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
def set_ip_adapter_scale(self, scale):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
Example:
```py
# To use original IP-Adapter
scale = 1.0
pipeline.set_ip_adapter_scale(scale)
# To use style block only
scale = {
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
# To use style+layout blocks
scale = {
"down": {"block_2": [0.0, 1.0]},
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
# To use style and layout from 2 reference images
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
pipeline.set_ip_adapter_scale(scales)
```
"""
unet_name = getattr(self, "unet_name", "unet")
unet = getattr(self, unet_name)
if not isinstance(scale, list):
scale = [scale]
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
for attn_name, attn_processor in unet.attn_processors.items():
if isinstance(
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
if isinstance(scale_config, dict):
for k, s in scale_config.items():
if attn_name.startswith(k):
attn_processor.scale[i] = s
else:
attn_processor.scale[i] = scale_config
def unload_ip_adapter(self):
"""
Unloads the IP Adapter weights
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# remove hidden encoder
if self.unet is None:
return
self.unet.encoder_hid_proj = None
self.unet.config.encoder_hid_dim_type = None
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
self.unet.text_encoder_hid_proj = None
self.unet.config.encoder_hid_dim_type = "text_proj"
# restore original Unet attention processors layers
attn_procs = {}
for name, value in self.unet.attn_processors.items():
attn_processor_class = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
)
attn_procs[name] = (
attn_processor_class
if isinstance(
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
)
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
class FluxIPAdapterMixin:
"""Mixin for handling Flux IP Adapters."""
@@ -617,6 +865,9 @@ class FluxIPAdapterMixin:
>>> ...
```
"""
# TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
@@ -636,9 +887,9 @@ class FluxIPAdapterMixin:
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
attn_processor_class = FluxAttnProcessor2_0()
attn_processor_class = FluxAttnProcessor()
attn_procs[name] = (
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)
+5 -2
View File
@@ -25,7 +25,6 @@ import torch.nn as nn
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
@@ -331,6 +330,8 @@ def _load_lora_into_text_encoder(
hotswap: bool = False,
metadata=None,
):
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -442,6 +443,8 @@ def _func_optionally_disable_offloading(_pipeline):
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
"""
from ..hooks.group_offloading import _is_group_offload_enabled
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
@@ -467,7 +470,7 @@ def _func_optionally_disable_offloading(_pipeline):
for _, component in _pipeline.components.items():
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
continue
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
+398
View File
@@ -5454,6 +5454,404 @@ class WanLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`SkyReelsV2Transformer3DModel`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
if any(k.startswith("diffusion_model.") for k in state_dict):
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
elif any(k.startswith("lora_unet_") for k in state_dict):
state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
@classmethod
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v
def _maybe_expand_t2v_lora_for_i2v(
cls,
transformer: torch.nn.Module,
state_dict,
):
if transformer.config.image_dim is None:
return state_dict
target_device = transformer.device
if any(k.startswith("transformer.blocks.") for k in state_dict):
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
has_bias = any(".lora_B.bias" in k for k in state_dict)
if is_i2v_lora:
return state_dict
for i in range(num_blocks):
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
# These keys should exist if the block `i` was part of the T2V LoRA.
ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
continue
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
)
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
)
# If the original LoRA had biases (indicated by has_bias)
# AND the specific reference bias key exists for this block.
ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
if has_bias and ref_key_lora_B_bias in state_dict:
ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
ref_lora_B_bias_tensor,
device=target_device,
)
return state_dict
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
state_dict = self._maybe_expand_t2v_lora_for_i2v(
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
state_dict=state_dict,
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`SkyReelsV2Transformer3DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
class CogView4LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
+3 -1
View File
@@ -22,7 +22,6 @@ from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
@@ -164,6 +163,8 @@ class PeftAdapterMixin:
from peft import inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -695,6 +696,7 @@ class PeftAdapterMixin:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `unload_lora()`.")
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import recurse_remove_peft_layers
recurse_remove_peft_layers(self)
@@ -24,6 +24,7 @@ from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
@@ -430,6 +431,7 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
empty_device_cache()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -46,6 +46,7 @@ from ..utils import (
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import empty_device_cache
if is_transformers_available():
@@ -1689,6 +1690,7 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
empty_device_cache()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2148,6 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
empty_device_cache()
else:
model.load_state_dict(diffusers_format_checkpoint)
+7 -9
View File
@@ -18,11 +18,8 @@ from ..models.embeddings import (
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
if is_accelerate_available():
@@ -84,13 +81,12 @@ class FluxTransformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ..models.attention_processor import (
FluxIPAdapterJointAttnProcessor2_0,
)
from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
if low_cpu_mem_usage:
if is_accelerate_available():
@@ -122,7 +118,7 @@ class FluxTransformer2DLoadersMixin:
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
attn_processor_class = FluxIPAdapterAttnProcessor
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
@@ -158,6 +154,8 @@ class FluxTransformer2DLoadersMixin:
key_id += 1
empty_device_cache()
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
+4
View File
@@ -18,6 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
logger = logging.get_logger(__name__)
@@ -80,6 +81,8 @@ class SD3Transformer2DLoadersMixin:
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)
empty_device_cache()
return attn_procs
def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +150,7 @@ class SD3Transformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_proj
+6 -1
View File
@@ -22,7 +22,6 @@ import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
@@ -44,6 +43,7 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.torch_utils import empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
@@ -132,6 +132,8 @@ class UNet2DConditionLoadersMixin:
)
```
"""
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -752,6 +754,7 @@ class UNet2DConditionLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_projection
@@ -849,6 +852,8 @@ class UNet2DConditionLoadersMixin:
key_id += 2
empty_device_cache()
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
+4
View File
@@ -26,6 +26,7 @@ _import_structure = {}
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -88,6 +89,7 @@ if is_torch_available():
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
@@ -111,6 +113,7 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel
from .autoencoders import (
AsymmetricAutoencoderKL,
@@ -176,6 +179,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PriorTransformer,
SanaTransformer2DModel,
SD3Transformer2DModel,
SkyReelsV2Transformer3DModel,
StableAudioDiTModel,
T5FilmDecoder,
Transformer2DModel,
+484 -3
View File
@@ -11,23 +11,504 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
if is_xformers_available():
import xformers as xops
else:
xops = None
logger = logging.get_logger(__name__)
class AttentionMixin:
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
"""
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
for module in self.modules():
if isinstance(module, AttentionModuleMixin):
module.fuse_projections()
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
for module in self.modules():
if isinstance(module, AttentionModuleMixin):
module.unfuse_projections()
class AttentionModuleMixin:
_default_processor_cls = None
_available_processors = []
fused_projections = False
def set_processor(self, processor: AttentionProcessor) -> None:
"""
Set the attention processor to use.
Args:
processor (`AttnProcessor`):
The attention processor to use.
"""
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
"""
Get the attention processor in use.
Args:
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to return the deprecated LoRA attention processor.
Returns:
"AttentionProcessor": The attention processor in use.
"""
if not return_deprecated_lora:
return self.processor
def set_attention_backend(self, backend: str):
from .attention_dispatch import AttentionBackendName
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend.lower())
self.processor._attention_backend = backend
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
"""
Set whether to use NPU flash attention from `torch_npu` or not.
Args:
use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
"""
if use_npu_flash_attention:
if not is_torch_npu_available():
raise ImportError("torch_npu is not available")
self.set_attention_backend("_native_npu")
def set_use_xla_flash_attention(
self,
use_xla_flash_attention: bool,
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
is_flux=False,
) -> None:
"""
Set whether to use XLA flash attention from `torch_xla` or not.
Args:
use_xla_flash_attention (`bool`):
Whether to use pallas flash attention kernel from `torch_xla` or not.
partition_spec (`Tuple[]`, *optional*):
Specify the partition specification if using SPMD. Otherwise None.
is_flux (`bool`, *optional*, defaults to `False`):
Whether the model is a Flux model.
"""
if use_xla_flash_attention:
if not is_torch_xla_available():
raise ImportError("torch_xla is not available")
self.set_attention_backend("_native_xla")
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
) -> None:
"""
Set whether to use memory efficient attention from `xformers` or not.
Args:
use_memory_efficient_attention_xformers (`bool`):
Whether to use memory efficient attention from `xformers` or not.
attention_op (`Callable`, *optional*):
The attention operation to use. Defaults to `None` which uses the default attention operation from
`xformers`.
"""
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
if is_xformers_available():
dtype = None
if attention_op is not None:
op_fw, op_bw = attention_op
dtype, *_ = op_fw.SUPPORTED_DTYPES
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
_ = xops.memory_efficient_attention(q, q, q)
except Exception as e:
raise e
self.set_attention_backend("xformers")
@torch.no_grad()
def fuse_projections(self):
"""
Fuse the query, key, and value projections into a single projection for efficiency.
"""
# Skip if already fused
if getattr(self, "fused_projections", False):
return
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if hasattr(self, "is_cross_attention") and self.is_cross_attention:
# Fuse cross-attention key-value projections
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
if hasattr(self, "use_bias") and self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
self.to_kv.bias.copy_(concatenated_bias)
else:
# Fuse self-attention projections
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if hasattr(self, "use_bias") and self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
# Handle added projections for models like SD3, Flux, etc.
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)
self.fused_projections = True
@torch.no_grad()
def unfuse_projections(self):
"""
Unfuse the query, key, and value projections back to separate projections.
"""
# Skip if not fused
if not getattr(self, "fused_projections", False):
return
# Remove fused projection layers
if hasattr(self, "to_qkv"):
delattr(self, "to_qkv")
if hasattr(self, "to_kv"):
delattr(self, "to_kv")
if hasattr(self, "to_added_qkv"):
delattr(self, "to_added_qkv")
self.fused_projections = False
def set_attention_slice(self, slice_size: int) -> None:
"""
Set the slice size for attention computation.
Args:
slice_size (`int`):
The slice size for attention computation.
"""
if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
processor = None
# Try to get a compatible processor for sliced attention
if slice_size is not None:
processor = self._get_compatible_processor("sliced")
# If no processor was found or slice_size is None, use default processor
if processor is None:
processor = self.default_processor_cls()
self.set_processor(processor)
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
Returns:
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
"""
Reshape the tensor for multi-head attention processing.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
Returns:
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
if tensor.ndim == 3:
batch_size, seq_len, dim = tensor.shape
extra_dim = 1
else:
batch_size, extra_dim, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)
if out_dim == 3:
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
return tensor
def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute the attention scores.
Args:
query (`torch.Tensor`): The query tensor.
key (`torch.Tensor`): The key tensor.
attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
Returns:
`torch.Tensor`: The attention probabilities/scores.
"""
dtype = query.dtype
if self.upcast_attention:
query = query.float()
key = key.float()
if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1
attention_scores = torch.baddbmm(
baddbmm_input,
query,
key.transpose(-1, -2),
beta=beta,
alpha=self.scale,
)
del baddbmm_input
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
del attention_scores
attention_probs = attention_probs.to(dtype)
return attention_probs
def prepare_attention_mask(
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
) -> torch.Tensor:
"""
Prepare the attention mask for the attention computation.
Args:
attention_mask (`torch.Tensor`): The attention mask to prepare.
target_length (`int`): The target length of the attention mask.
batch_size (`int`): The batch size for repeating the attention mask.
out_dim (`int`, *optional*, defaults to `3`): Output dimension.
Returns:
`torch.Tensor`: The prepared attention mask.
"""
head_size = self.heads
if attention_mask is None:
return attention_mask
current_length: int = attention_mask.shape[-1]
if current_length != target_length:
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([attention_mask, padding], dim=2)
else:
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
# we want to instead pad by (0, remaining_length), where remaining_length is:
# remaining_length: int = target_length - current_length
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
return attention_mask
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
"""
Normalize the encoder hidden states.
Args:
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
Returns:
`torch.Tensor`: The normalized encoder hidden states.
"""
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
if isinstance(self.norm_cross, nn.LayerNorm):
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
elif isinstance(self.norm_cross, nn.GroupNorm):
# Group norm norms along the channels dimension and expects
# input to be in the shape of (N, C, *). In this case, we want
# to norm along the hidden dimension, so we need to move
# (batch_size, sequence_length, hidden_size) ->
# (batch_size, hidden_size, sequence_length)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
else:
assert False
return encoder_hidden_states
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
File diff suppressed because it is too large Load Diff
+105 -663
View File
@@ -2272,558 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0:
return hidden_states
class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FluxAttnProcessor2_0_NPU:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FusedFluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
# `context` projections.
if encoder_hidden_states is not None:
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
(
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
) = torch.split(encoder_qkv, split_size, dim=-1)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FusedFluxAttnProcessor2_0_NPU:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
# `context` projections.
if encoder_hidden_states is not None:
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
(
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
) = torch.split(encoder_qkv, split_size, dim=-1)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
ip_hidden_states: Optional[List[torch.Tensor]] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
hidden_states_query_proj = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP-adapter
ip_query = hidden_states_query_proj
ip_attn_output = torch.zeros_like(hidden_states)
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
ip_attn_output += scale * current_ip_hidden_states
return hidden_states, encoder_hidden_states, ip_attn_output
else:
return hidden_states
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3453,106 +2901,6 @@ class XLAFlashAttnProcessor2_0:
return hidden_states
class XLAFluxFlashAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
"""
def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
if is_torch_xla_version("<", "2.3"):
raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
if is_spmd() and is_torch_xla_version("<", "2.4"):
raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
self.partition_spec = partition_spec
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query /= math.sqrt(head_dim)
hidden_states = flash_attention(query, key, value, causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class MochiVaeAttnProcessor2_0:
r"""
Attention processor used in Mochi VAE.
@@ -5992,17 +5340,6 @@ class LoRAAttnAddedKVProcessor:
pass
class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
super().__init__()
class SanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
@@ -6167,6 +5504,111 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
return hidden_states
class FluxAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxAttnProcessor
return FluxAttnProcessor(*args, **kwargs)
class FluxSingleAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __new__(cls, *args, **kwargs):
deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxAttnProcessor
return FluxAttnProcessor(*args, **kwargs)
class FusedFluxAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxAttnProcessor
return FluxAttnProcessor(*args, **kwargs)
class FluxIPAdapterJointAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
return FluxIPAdapterAttnProcessor(*args, **kwargs)
class FluxAttnProcessor2_0_NPU:
def __new__(cls, *args, **kwargs):
deprecation_message = (
"FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
"alternative solution to use NPU Flash Attention will be provided in the future."
)
deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
from .transformers.transformer_flux import FluxAttnProcessor
processor = FluxAttnProcessor()
processor._attention_backend = "_native_npu"
return processor
class FusedFluxAttnProcessor2_0_NPU:
def __new__(self):
deprecation_message = (
"FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
"alternative solution to use NPU Flash Attention will be provided in the future."
)
deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
from .transformers.transformer_flux import FluxAttnProcessor
processor = FluxAttnProcessor()
processor._attention_backend = "_fused_npu"
return processor
class XLAFluxFlashAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
"""
def __new__(cls, *args, **kwargs):
deprecation_message = (
"XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
"alternative solution to using XLA Flash Attention will be provided in the future."
)
deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
if is_torch_xla_version("<", "2.3"):
raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
if is_spmd() and is_torch_xla_version("<", "2.4"):
raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
from .transformers.transformer_flux import FluxAttnProcessor
if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
deprecation_message = (
"partition_spec was not used in the processor implementation when it was added. Passing it "
"is a no-op and support for it will be removed."
)
deprecate("partition_spec", "1.0.0", deprecation_message)
processor = FluxAttnProcessor(*args, **kwargs)
processor._attention_backend = "_native_xla"
return processor
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[control_idx]
if from_multi:
if from_multi or len(control_type_idx) == 1:
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
else:
@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
if from_multi:
if from_multi or len(control_type_idx) == 1:
controlnet_cond_fuser += condition + alpha
else:
controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
if from_multi:
if from_multi or len(control_type_idx) == 1:
scales = scales * conditioning_scale[0]
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
elif from_multi:
elif from_multi or len(control_type_idx) == 1:
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
+26 -34
View File
@@ -319,7 +319,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
"""
This function generates 1D positional embeddings from a grid.
@@ -352,6 +352,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)
return emb
@@ -1176,6 +1181,7 @@ def apply_rotary_emb(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1193,8 +1199,15 @@ def apply_rotary_emb(
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
if sequence_dim == 2:
cos = cos[None, None, :, :]
sin = sin[None, None, :, :]
elif sequence_dim == 1:
cos = cos[None, :, None, :]
sin = sin[None, :, None, :]
else:
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
@@ -1238,37 +1251,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
return x
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
class TimestepEmbedding(nn.Module):
def __init__(
self,
@@ -2619,3 +2601,13 @@ class MultiIPAdapterImageProjection(nn.Module):
projected_image_embeds.append(image_embed)
return projected_image_embeds
class FluxPosEmbed(nn.Module):
def __new__(cls, *args, **kwargs):
deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxPosEmbed
return FluxPosEmbed(*args, **kwargs)
+64 -1
View File
@@ -16,9 +16,10 @@
import importlib
import inspect
import math
import os
from array import array
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -38,6 +39,7 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -252,6 +254,10 @@ def load_model_dict_into_meta(
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype
if is_accelerate_version(">", "1.8.1"):
set_module_kwargs["non_blocking"] = True
set_module_kwargs["clear_cache"] = False
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -520,3 +526,60 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters
def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
mismatched_keys = []
if not ignore_mismatched_sizes:
return mismatched_keys
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
def _expand_device_map(device_map, param_names):
"""
Expand a device map to return the correspondence parameter name to device.
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
"""
This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
very large margin.
"""
# Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device)
for param, device in expanded_device_map.items()
if str(device) not in ["cpu", "disk"]
}
parameter_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += math.prod(param.shape)
# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items():
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+89 -41
View File
@@ -62,10 +62,14 @@ from ..utils.hub_utils import (
load_or_create_model_card,
populate_model_card,
)
from ..utils.torch_utils import empty_device_cache
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
_expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_find_mismatched_keys,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
@@ -168,7 +172,11 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
for name, param in parameter.named_parameters():
last_dtype = param.dtype
if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
if (
hasattr(parameter, "_keep_in_fp32_modules")
and parameter._keep_in_fp32_modules
and any(m in name for m in parameter._keep_in_fp32_modules)
):
continue
if param.is_floating_point():
@@ -602,6 +610,60 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_to_disk_path=offload_to_disk_path,
)
def set_attention_backend(self, backend: str) -> None:
"""
Set the attention backend for the model.
Args:
backend (`str`):
The name of the backend to set. Must be one of the available backends defined in
`AttentionBackendName`. Available backends can be found in
`diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
attention as backend.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
backend = backend.lower()
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
processor._attention_backend = backend
def reset_attention_backend(self) -> None:
"""
Resets the attention backend for the model. Following calls to `forward` will use the environment default or
the torch native scaled dot product attention.
"""
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
processor._attention_backend = None
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -1469,11 +1531,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -1482,18 +1539,27 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if offload_folder is not None:
else:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
# If a device map has been used, we can speedup the load time by warming up the device caching allocator.
# If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
# lot of individual calls to device malloc). We can, however, preallocate the memory required by the
# tensors using their expected shape and not performing any initialization of the memory (empty data).
# When the actual device allocations happen, the allocator already has a pool of unused device memory
# that it can re-use for faster loading of the model.
# TODO: add support for warmup with hf_quantizer
if device_map is not None and hf_quantizer is None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
_caching_allocator_warmup(model, expanded_device_map, dtype)
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
state_dict_folder, state_dict_index = None, None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None
if state_dict is not None:
# load_state_dict will manage the case where we pass a dict instead of a file
@@ -1503,38 +1569,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
)
if low_cpu_mem_usage:
@@ -1554,9 +1596,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
empty_device_cache()
if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)
offload_index = None
@@ -1892,4 +1935,9 @@ class LegacyModelMixin(ModelMixin):
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
if remapped_class is cls:
return super(LegacyModelMixin, remapped_class).from_pretrained(
pretrained_model_name_or_path, **kwargs_copy
)
else:
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
@@ -31,6 +31,7 @@ if is_torch_available():
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
@@ -24,19 +24,13 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import (
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ..attention import AttentionMixin, FeedForward
from ..cache_utils import CacheMixin
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
from .transformer_flux import FluxAttention, FluxAttnProcessor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -223,6 +217,8 @@ class ChromaSingleTransformerBlock(nn.Module):
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
from ..attention_processor import FluxAttnProcessor2_0_NPU
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
@@ -230,17 +226,15 @@ class ChromaSingleTransformerBlock(nn.Module):
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
processor = FluxAttnProcessor2_0()
processor = FluxAttnProcessor()
self.attn = Attention(
self.attn = FluxAttention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
@@ -292,17 +286,15 @@ class ChromaTransformerBlock(nn.Module):
self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
self.attn = Attention(
self.attn = FluxAttention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=FluxAttnProcessor2_0(),
qk_norm=qk_norm,
processor=FluxAttnProcessor(),
eps=eps,
)
@@ -376,7 +368,13 @@ class ChromaTransformerBlock(nn.Module):
class ChromaTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
ModelMixin,
ConfigMixin,
PeftAdapterMixin,
FromOriginalModelMixin,
FluxTransformer2DLoadersMixin,
CacheMixin,
AttentionMixin,
):
"""
The Transformer model introduced in Flux, modified for Chroma.
@@ -475,106 +473,6 @@ class ChromaTransformer2DModel(
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedFluxAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
hidden_states: torch.Tensor,
@@ -187,9 +187,15 @@ class CosmosAttnProcessor2_0:
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
# 4. Prepare for GQA
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
if torch.onnx.is_in_onnx_export():
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
@@ -12,28 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import (
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
apply_rotary_emb,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -42,6 +42,307 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
encoder_query = encoder_key = encoder_value = (None,)
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
class FluxAttnProcessor:
_attention_backend = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "FluxAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FluxIPAdapterAttnProcessor(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
_attention_backend = None
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
def __call__(
self,
attn: "FluxAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
ip_hidden_states: Optional[List[torch.Tensor]] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size = hidden_states.shape[0]
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
ip_query = query
if encoder_hidden_states is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP-adapter
ip_attn_output = torch.zeros_like(hidden_states)
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
current_ip_hidden_states = dispatch_attention_fn(
ip_query,
ip_key,
ip_value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
)
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
ip_attn_output += scale * current_ip_hidden_states
return hidden_states, encoder_hidden_states, ip_attn_output
else:
return hidden_states
class FluxAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = FluxAttnProcessor
_available_processors = [
FluxAttnProcessor,
FluxIPAdapterAttnProcessor,
]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
@@ -54,6 +355,8 @@ class FluxSingleTransformerBlock(nn.Module):
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
from ..attention_processor import FluxAttnProcessor2_0_NPU
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
@@ -61,17 +364,15 @@ class FluxSingleTransformerBlock(nn.Module):
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
processor = FluxAttnProcessor2_0()
processor = FluxAttnProcessor()
self.attn = Attention(
self.attn = FluxAttention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
@@ -118,17 +419,15 @@ class FluxTransformerBlock(nn.Module):
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = Attention(
self.attn = FluxAttention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=FluxAttnProcessor2_0(),
qk_norm=qk_norm,
processor=FluxAttnProcessor(),
eps=eps,
)
@@ -152,6 +451,7 @@ class FluxTransformerBlock(nn.Module):
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
@@ -180,7 +480,6 @@ class FluxTransformerBlock(nn.Module):
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
@@ -195,8 +494,45 @@ class FluxTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
ModelMixin,
ConfigMixin,
PeftAdapterMixin,
FromOriginalModelMixin,
FluxTransformer2DLoadersMixin,
CacheMixin,
AttentionMixin,
):
"""
The Transformer model introduced in Flux.
@@ -292,106 +628,6 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedFluxAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
hidden_states: torch.Tensor,
@@ -490,6 +726,7 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
@@ -521,6 +758,7 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
@@ -0,0 +1,607 @@
# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import (
PixArtAlphaTextProjection,
TimestepEmbedding,
get_1d_rotary_pos_embed,
get_1d_sincos_pos_embed_from_grid,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin, get_parameter_dtype
from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SkyReelsV2AttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
# 512 is the context length of the text encoder, hardcoded for now
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if rotary_emb is not None:
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img = attn.add_k_proj(encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
hidden_states_img = F.scaled_dot_product_attention(
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
hidden_states = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding
class SkyReelsV2ImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
if pos_embed_seq_len is not None:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
else:
self.pos_embed = None
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
if self.pos_embed is not None:
batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
return hidden_states
class SkyReelsV2Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, output_type: str = "pt"):
super().__init__()
self.num_channels = num_channels
self.output_type = output_type
self.flip_sin_to_cos = flip_sin_to_cos
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
original_shape = timesteps.shape
t_emb = get_1d_sincos_pos_embed_from_grid(
self.num_channels,
timesteps,
output_type=self.output_type,
flip_sin_to_cos=self.flip_sin_to_cos,
)
# Reshape back to maintain batch structure
if len(original_shape) > 1:
t_emb = t_emb.reshape(*original_shape, self.num_channels)
return t_emb
class SkyReelsV2TimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
self.timesteps_proj = SkyReelsV2Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = SkyReelsV2ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
def forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = get_parameter_dtype(self.time_embedder)
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None:
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
class SkyReelsV2RotaryPosEmbed(nn.Module):
def __init__(
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.patch_size = patch_size
self.max_seq_len = max_seq_len
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
freqs = []
for dim in [t_dim, h_dim, w_dim]:
freq = get_1d_rotary_pos_embed(
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float32
)
freqs.append(freq)
self.freqs = torch.cat(freqs, dim=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
freqs = self.freqs.to(hidden_states.device)
freqs = freqs.split_with_sizes(
[
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
self.attention_head_dim // 6,
self.attention_head_dim // 6,
],
dim=1,
)
freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
return freqs
class SkyReelsV2TransformerBlock(nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: Optional[int] = None,
):
super().__init__()
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = Attention(
query_dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
processor=SkyReelsV2AttnProcessor2_0(),
)
# 2. Cross-attention
self.attn2 = Attention(
query_dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
added_proj_bias=True,
processor=SkyReelsV2AttnProcessor2_0(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
# 3. Feed-forward
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
if temb.dim() == 3:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
elif temb.dim() == 4:
# For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(
hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask
)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
hidden_states
)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
Args:
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
num_attention_heads (`int`, defaults to `16`):
Fixed length for text embeddings.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_dim (`int`, defaults to `4096`):
Input dimension for text embeddings.
freq_dim (`int`, defaults to `256`):
Dimension for sinusoidal time embeddings.
ffn_dim (`int`, defaults to `8192`):
Intermediate dimension in feed-forward network.
num_layers (`int`, defaults to `32`):
The number of layers of transformer blocks to use.
window_size (`Tuple[int]`, defaults to `(-1, -1)`):
Window size for local attention (-1 indicates global attention).
cross_attn_norm (`bool`, defaults to `True`):
Enable cross-attention normalization.
qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
Enable query/key normalization.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
inject_sample_info (`bool`, defaults to `False`):
Whether to inject sample information into the model.
image_dim (`int`, *optional*):
The dimension of the image embeddings.
added_kv_proj_dim (`int`, *optional*):
The dimension of the added key/value projection.
rope_max_seq_len (`int`, defaults to `1024`):
The maximum sequence length for the rotary embeddings.
pos_embed_seq_len (`int`, *optional*):
The sequence length for the positional embeddings.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["SkyReelsV2TransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
num_attention_heads: int = 16,
attention_head_dim: int = 128,
in_channels: int = 16,
out_channels: int = 16,
text_dim: int = 4096,
freq_dim: int = 256,
ffn_dim: int = 8192,
num_layers: int = 32,
cross_attn_norm: bool = True,
qk_norm: Optional[str] = "rms_norm_across_heads",
eps: float = 1e-6,
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
pos_embed_seq_len: Optional[int] = None,
inject_sample_info: bool = False,
num_frame_per_block: int = 1,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Patch & position embedding
self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
# 2. Condition embeddings
# image_embedding_dim=1280 for I2V model
self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=freq_dim,
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
pos_embed_seq_len=pos_embed_seq_len,
)
# 3. Transformer blocks
self.blocks = nn.ModuleList(
[
SkyReelsV2TransformerBlock(
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
)
for _ in range(num_layers)
]
)
# 4. Output norm & projection
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
if inject_sample_info:
self.fps_embedding = nn.Embedding(2, inner_dim)
self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="linear-silu")
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
enable_diffusion_forcing: bool = False,
fps: Optional[torch.Tensor] = None,
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
rotary_emb = self.rope(hidden_states)
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
causal_mask = None
if self.config.num_frame_per_block > 1:
block_num = post_patch_num_frames // self.config.num_frame_per_block
range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave(
self.config.num_frame_per_block
)
causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1)
causal_mask = causal_mask.repeat(
1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width
)
causal_mask = causal_mask.reshape(
post_patch_num_frames * post_patch_height * post_patch_width,
post_patch_num_frames * post_patch_height * post_patch_width,
)
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
)
timestep_proj = timestep_proj.unflatten(-1, (6, -1))
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
if self.config.inject_sample_info:
fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device)
fps_emb = self.fps_embedding(fps)
if enable_diffusion_forcing:
timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat(
timestep.shape[1], 1, 1
)
else:
timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1))
if enable_diffusion_forcing:
b, f = timestep.shape
temb = temb.view(b, f, 1, 1, -1)
timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1) # (b, f, 1, 1, 6, inner_dim)
temb = temb.repeat(1, 1, post_patch_height, post_patch_width, 1).flatten(1, 3)
timestep_proj = timestep_proj.repeat(1, 1, post_patch_height, post_patch_width, 1, 1).flatten(
1, 3
) # (b, f, pp_h, pp_w, 6, inner_dim) -> (b, f * pp_h * pp_w, 6, inner_dim)
timestep_proj = timestep_proj.transpose(1, 2).contiguous() # (b, 6, f * pp_h * pp_w, inner_dim)
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.blocks:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
timestep_proj,
rotary_emb,
causal_mask,
)
else:
for block in self.blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states,
timestep_proj,
rotary_emb,
causal_mask,
)
if temb.dim() == 2:
# If temb is 2D, we assume it has time 1-D time embedding values for each batch.
# For models:
# - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
# - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
# - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
# - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
# - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
elif temb.dim() == 3:
# If temb is 3D, we assume it has 2-D time embedding values for each batch.
# Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
# For models:
# - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
# - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
# - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
shift, scale = shift.squeeze(1), scale.squeeze(1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
# first device rather than the last device, which hidden_states ends up
# on.
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def _set_ar_attention(self, causal_block_size: int):
self.register_to_config(num_frame_per_block=causal_block_size)
@@ -165,7 +165,7 @@ class UNet2DConditionModel(
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"]
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["BasicTransformerBlock"]
@@ -0,0 +1,86 @@
from typing import TYPE_CHECKING
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_pt_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["modular_pipeline"] = [
"ModularPipelineBlocks",
"ModularPipeline",
"PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
"LoopSequentialPipelineBlocks",
"PipelineState",
"BlockState",
]
_import_structure["modular_pipeline_utils"] = [
"ComponentSpec",
"ConfigSpec",
"InputParam",
"OutputParam",
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["components_manager"] = ["ComponentsManager"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .components_manager import ComponentsManager
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
LoopSequentialPipelineBlocks,
ModularPipeline,
ModularPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
from .modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
InsertableDict,
OutputParam,
)
from .stable_diffusion_xl import (
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from .wan import WanAutoBlocks, WanModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,673 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils import is_torch_available, logging
if is_torch_available():
pass
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class InsertableDict(OrderedDict):
def insert(self, key, value, index):
items = list(self.items())
# Remove key if it already exists to avoid duplicates
items = [(k, v) for k, v in items if k != key]
# Insert at the specified index
items.insert(index, (key, value))
# Clear and update self
self.clear()
self.update(items)
# Return self for method chaining
return self
def __repr__(self):
if not self:
return "InsertableDict()"
items = []
for i, (key, value) in enumerate(self.items()):
if isinstance(value, type):
# For classes, show class name and <class ...>
obj_repr = f"<class '{value.__module__}.{value.__name__}'>"
else:
# For objects (instances) and other types, show class name and module
obj_repr = f"<obj '{value.__class__.__module__}.{value.__class__.__name__}'>"
items.append(f"{i}: ({repr(key)}, {obj_repr})")
return "InsertableDict([\n " + ",\n ".join(items) + "\n])"
# YiYi TODO:
# 1. validate the dataclass fields
# 2. improve the docstring and potentially add a validator for load methods, make sure they are valid inputs to pass to from_pretrained()
@dataclass
class ComponentSpec:
"""Specification for a pipeline component.
A component can be created in two ways:
1. From scratch using __init__ with a config dict
2. using `from_pretrained`
Attributes:
name: Name of the component
type_hint: Type of the component (e.g. UNet2DConditionModel)
description: Optional description of the component
config: Optional config dict for __init__ creation
repo: Optional repo path for from_pretrained creation
subfolder: Optional subfolder in repo
variant: Optional variant in repo
revision: Optional revision in repo
default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
"""
name: Optional[str] = None
type_hint: Optional[Type] = None
description: Optional[str] = None
config: Optional[FrozenDict] = None
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default="", metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
def __hash__(self):
"""Make ComponentSpec hashable, using load_id as the hash value."""
return hash((self.name, self.load_id, self.default_creation_method))
def __eq__(self, other):
"""Compare ComponentSpec objects based on name and load_id."""
if not isinstance(other, ComponentSpec):
return False
return (
self.name == other.name
and self.load_id == other.load_id
and self.default_creation_method == other.default_creation_method
)
@classmethod
def from_component(cls, name: str, component: Any) -> Any:
"""Create a ComponentSpec from a Component.
Currently supports:
- Components created with `ComponentSpec.load()` method
- Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders)
Args:
name: Name of the component
component: Component object to create spec from
Returns:
ComponentSpec object
Raises:
ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin)
"""
# Check if component was created with ComponentSpec.load()
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
# component has a usable load_id -> from_pretrained, no warning needed
default_creation_method = "from_pretrained"
else:
# Component doesn't have a usable load_id, check if it's a nn.Module
if isinstance(component, torch.nn.Module):
raise ValueError(
"Cannot create ComponentSpec from a nn.Module that was not created with `ComponentSpec.load()` method."
)
# ConfigMixin objects without weights (e.g. scheduler & guider) can be recreated with from_config
elif isinstance(component, ConfigMixin):
# warn if component was not created with `ComponentSpec`
if not hasattr(component, "_diffusers_load_id"):
logger.warning(
"Component was not created using `ComponentSpec`, defaulting to `from_config` creation method"
)
default_creation_method = "from_config"
else:
# Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error
raise ValueError(
f"Cannot create ComponentSpec from {name}({component.__class__.__name__}). Currently ComponentSpec.from_component() only supports: "
f" - components created with `ComponentSpec.load()` method"
f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)."
)
type_hint = component.__class__
if isinstance(component, ConfigMixin) and default_creation_method == "from_config":
config = component.config
else:
config = None
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
load_spec = cls.decode_load_id(component._diffusers_load_id)
else:
load_spec = {}
return cls(
name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec
)
@classmethod
def loading_fields(cls) -> List[str]:
"""
Return the names of all loadingrelated fields (i.e. those whose field.metadata["loading"] is True).
"""
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
@property
def load_id(self) -> str:
"""
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
segments).
"""
if self.default_creation_method == "from_config":
return "null"
parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p)
@classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
"""
Decode a load_id string back into a dictionary of loading fields and values.
Args:
load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
where None values are represented as "null"
Returns:
Dict mapping loading field names to their values. e.g. {
"repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision"
} If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
component not created with `load` method).
"""
# Get all loading fields in order
loading_fields = cls.loading_fields()
result = {f: None for f in loading_fields}
if load_id == "null":
return result
# Split the load_id
parts = load_id.split("|")
# Map parts to loading fields by position
for i, part in enumerate(parts):
if i < len(loading_fields):
# Convert "null" string back to None
result[loading_fields[i]] = None if part == "null" else part
return result
# YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
# otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
# the config info is lost in the process
# remove error check in from_component spec and ModularPipeline.update_components() if we remove support for non configmixin in `create()` method
def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
"""Create component using from_config with config."""
if self.type_hint is None or not isinstance(self.type_hint, type):
raise ValueError("`type_hint` is required when using from_config creation method.")
config = config or self.config or {}
if issubclass(self.type_hint, ConfigMixin):
component = self.type_hint.from_config(config, **kwargs)
else:
signature_params = inspect.signature(self.type_hint.__init__).parameters
init_kwargs = {}
for k, v in config.items():
if k in signature_params:
init_kwargs[k] = v
for k, v in kwargs.items():
if k in signature_params:
init_kwargs[k] = v
component = self.type_hint(**init_kwargs)
component._diffusers_load_id = "null"
if hasattr(component, "config"):
self.config = component.config
return component
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
def load(self, **kwargs) -> Any:
"""Load component using from_pretrained."""
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
# merge loading field value in the spec with user passed values to create load_kwargs
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
repo = load_kwargs.pop("repo", None)
if repo is None:
raise ValueError(
"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)"
)
if self.type_hint is None:
try:
from diffusers import AutoModel
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
# update type_hint if AutoModel load successfully
self.type_hint = component.__class__
else:
try:
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} using load method: {e}")
self.repo = repo
for k, v in load_kwargs.items():
setattr(self, k, v)
component._diffusers_load_id = self.load_id
return component
@dataclass
class ConfigSpec:
"""Specification for a pipeline configuration parameter."""
name: str
default: Any
description: Optional[str] = None
# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
# however some fields are not relevant for intermediate_inputs
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
# -> should we use different class for inputs and intermediate_inputs?
@dataclass
class InputParam:
"""Specification for an input parameter."""
name: str = None
type_hint: Any = None
default: Any = None
required: bool = False
description: str = ""
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@dataclass
class OutputParam:
"""Specification for an output parameter."""
name: str
type_hint: Any = None
description: str = ""
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
def __repr__(self):
return (
f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
)
def format_inputs_short(inputs):
"""
Format input parameters into a string representation, with required params first followed by optional ones.
Args:
inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params
Returns:
str: Formatted string of input parameters
Example:
>>> inputs = [ ... InputParam(name="prompt", required=True), ... InputParam(name="image", required=True), ...
InputParam(name="guidance_scale", required=False, default=7.5), ... InputParam(name="num_inference_steps",
required=False, default=50) ... ] >>> format_inputs_short(inputs) 'prompt, image, guidance_scale=7.5,
num_inference_steps=50'
"""
required_inputs = [param for param in inputs if param.required]
optional_inputs = [param for param in inputs if not param.required]
required_str = ", ".join(param.name for param in required_inputs)
optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
inputs_str = required_str
if optional_str:
inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
return inputs_str
def format_intermediates_short(intermediate_inputs, required_intermediate_inputs, intermediate_outputs):
"""
Formats intermediate inputs and outputs of a block into a string representation.
Args:
intermediate_inputs: List of intermediate input parameters
required_intermediate_inputs: List of required intermediate input names
intermediate_outputs: List of intermediate output parameters
Returns:
str: Formatted string like:
Intermediates:
- inputs: Required(latents), dtype
- modified: latents # variables that appear in both inputs and outputs
- outputs: images # new outputs only
"""
# Handle inputs
input_parts = []
for inp in intermediate_inputs:
if inp.name in required_intermediate_inputs:
input_parts.append(f"Required({inp.name})")
else:
if inp.name is None and inp.kwargs_type is not None:
inp_name = "*_" + inp.kwargs_type
else:
inp_name = inp.name
input_parts.append(inp_name)
# Handle modified variables (appear in both inputs and outputs)
inputs_set = {inp.name for inp in intermediate_inputs}
modified_parts = []
new_output_parts = []
for out in intermediate_outputs:
if out.name in inputs_set:
modified_parts.append(out.name)
else:
new_output_parts.append(out.name)
result = []
if input_parts:
result.append(f" - inputs: {', '.join(input_parts)}")
if modified_parts:
result.append(f" - modified: {', '.join(modified_parts)}")
if new_output_parts:
result.append(f" - outputs: {', '.join(new_output_parts)}")
return "\n".join(result) if result else " (none)"
def format_params(params, header="Args", indent_level=4, max_line_length=115):
"""Format a list of InputParam or OutputParam objects into a readable string representation.
Args:
params: List of InputParam or OutputParam objects to format
header: Header text to use (e.g. "Args" or "Returns")
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all parameters
"""
if not params:
return ""
base_indent = " " * indent_level
param_indent = " " * (indent_level + 4)
desc_indent = " " * (indent_level + 8)
formatted_params = []
def get_type_str(type_hint):
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
return f"Union[{', '.join(types)}]"
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
def wrap_text(text, indent, max_length):
"""Wrap text while preserving markdown links and maintaining indentation."""
words = text.split()
lines = []
current_line = []
current_length = 0
for word in words:
word_length = len(word) + (1 if current_line else 0)
if current_line and current_length + word_length > max_length:
lines.append(" ".join(current_line))
current_line = [word]
current_length = len(word)
else:
current_line.append(word)
current_length += word_length
if current_line:
lines.append(" ".join(current_line))
return f"\n{indent}".join(lines)
# Add the header
formatted_params.append(f"{base_indent}{header}:")
for param in params:
# Format parameter name and type
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
# YiYi Notes: remove this line if we remove kwargs_type
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
param_str = f"{param_indent}{name} (`{type_str}`"
# Add optional tag and default value if parameter is an InputParam and optional
if hasattr(param, "required"):
if not param.required:
param_str += ", *optional*"
if param.default is not None:
param_str += f", defaults to {param.default}"
param_str += "):"
# Add description on a new line with additional indentation and wrapping
if param.description:
desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description)
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}"
formatted_params.append(param_str)
return "\n\n".join(formatted_params)
def format_input_params(input_params, indent_level=4, max_line_length=115):
"""Format a list of InputParam objects into a readable string representation.
Args:
input_params: List of InputParam objects to format
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all input parameters
"""
return format_params(input_params, "Inputs", indent_level, max_line_length)
def format_output_params(output_params, indent_level=4, max_line_length=115):
"""Format a list of OutputParam objects into a readable string representation.
Args:
output_params: List of OutputParam objects to format
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all output parameters
"""
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ComponentSpec objects into a readable string representation.
Args:
components: List of ComponentSpec objects to format
indent_level: Number of spaces to indent each component line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
add_empty_lines: Whether to add empty lines between components (default: True)
Returns:
A formatted string representing all components
"""
if not components:
return ""
base_indent = " " * indent_level
component_indent = " " * (indent_level + 4)
formatted_components = []
# Add the header
formatted_components.append(f"{base_indent}Components:")
if add_empty_lines:
formatted_components.append("")
# Add each component with optional empty lines between them
for i, component in enumerate(components):
# Get type name, handling special cases
type_name = (
component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
)
component_desc = f"{component_indent}{component.name} (`{type_name}`)"
if component.description:
component_desc += f": {component.description}"
# Get the loading fields dynamically
loading_field_values = []
for field_name in component.loading_fields():
field_value = getattr(component, field_name)
if field_value is not None:
loading_field_values.append(f"{field_name}={field_value}")
# Add loading field information if available
if loading_field_values:
component_desc += f" [{', '.join(loading_field_values)}]"
formatted_components.append(component_desc)
# Add an empty line after each component except the last one
if add_empty_lines and i < len(components) - 1:
formatted_components.append("")
return "\n".join(formatted_components)
def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ConfigSpec objects into a readable string representation.
Args:
configs: List of ConfigSpec objects to format
indent_level: Number of spaces to indent each config line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
add_empty_lines: Whether to add empty lines between configs (default: True)
Returns:
A formatted string representing all configs
"""
if not configs:
return ""
base_indent = " " * indent_level
config_indent = " " * (indent_level + 4)
formatted_configs = []
# Add the header
formatted_configs.append(f"{base_indent}Configs:")
if add_empty_lines:
formatted_configs.append("")
# Add each config with optional empty lines between them
for i, config in enumerate(configs):
config_desc = f"{config_indent}{config.name} (default: {config.default})"
if config.description:
config_desc += f": {config.description}"
formatted_configs.append(config_desc)
# Add an empty line after each config except the last one
if add_empty_lines and i < len(configs) - 1:
formatted_configs.append("")
return "\n".join(formatted_configs)
def make_doc_string(
inputs,
intermediate_inputs,
outputs,
description="",
class_name=None,
expected_components=None,
expected_configs=None,
):
"""
Generates a formatted documentation string describing the pipeline block's parameters and structure.
Args:
inputs: List of input parameters
intermediate_inputs: List of intermediate input parameters
outputs: List of output parameters
description (str, *optional*): Description of the block
class_name (str, *optional*): Name of the class to include in the documentation
expected_components (List[ComponentSpec], *optional*): List of expected components
expected_configs (List[ConfigSpec], *optional*): List of expected configurations
Returns:
str: A formatted string containing information about components, configs, call parameters,
intermediate inputs/outputs, and final outputs.
"""
output = ""
# Add class name if provided
if class_name:
output += f"class {class_name}\n\n"
# Add description
if description:
desc_lines = description.strip().split("\n")
aligned_desc = "\n".join(" " + line for line in desc_lines)
output += aligned_desc + "\n\n"
# Add components section if provided
if expected_components and len(expected_components) > 0:
components_str = format_components(expected_components, indent_level=2)
output += components_str + "\n\n"
# Add configs section if provided
if expected_configs and len(expected_configs) > 0:
configs_str = format_configs(expected_configs, indent_level=2)
output += configs_str + "\n\n"
# Add inputs section
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
# Add outputs section
output += "\n\n"
output += format_output_params(outputs, indent_level=2)
return output
@@ -0,0 +1,665 @@
import json
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from ..configuration_utils import ConfigMixin
from ..image_processor import PipelineImageInput
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
from .modular_pipeline_utils import InputParam
logger = logging.getLogger(__name__)
# YiYi Notes: this is actually for SDXL, put it here for now
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam(
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
),
"prompt_2": InputParam(
"prompt_2",
type_hint=Union[str, List[str]],
description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
),
"negative_prompt": InputParam(
"negative_prompt",
type_hint=Union[str, List[str]],
description="The prompt or prompts not to guide the image generation",
),
"negative_prompt_2": InputParam(
"negative_prompt_2",
type_hint=Union[str, List[str]],
description="The negative prompt or prompts for text_encoder_2",
),
"cross_attention_kwargs": InputParam(
"cross_attention_kwargs",
type_hint=Optional[dict],
description="Kwargs dictionary passed to the AttentionProcessor",
),
"clip_skip": InputParam(
"clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
),
"image": InputParam(
"image",
type_hint=PipelineImageInput,
required=True,
description="The image(s) to modify for img2img or inpainting",
),
"mask_image": InputParam(
"mask_image",
type_hint=PipelineImageInput,
required=True,
description="Mask image for inpainting, white pixels will be repainted",
),
"generator": InputParam(
"generator",
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
description="Generator(s) for deterministic generation",
),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam(
"num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
),
"timesteps": InputParam(
"timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
),
"sigmas": InputParam(
"sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
),
"denoising_end": InputParam(
"denoising_end",
type_hint=Optional[float],
description="Fraction of denoising process to complete before termination",
),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam(
"strength", type_hint=float, default=0.3, description="How much to transform the reference image"
),
"denoising_start": InputParam(
"denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
),
"latents": InputParam(
"latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
),
"padding_mask_crop": InputParam(
"padding_mask_crop",
type_hint=Optional[Tuple[int, int]],
description="Size of margin in crop for image and mask",
),
"original_size": InputParam(
"original_size",
type_hint=Optional[Tuple[int, int]],
description="Original size of the image for SDXL's micro-conditioning",
),
"target_size": InputParam(
"target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
),
"negative_original_size": InputParam(
"negative_original_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on image resolution",
),
"negative_target_size": InputParam(
"negative_target_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on target resolution",
),
"crops_coords_top_left": InputParam(
"crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Top-left coordinates for SDXL's micro-conditioning",
),
"negative_crops_coords_top_left": InputParam(
"negative_crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Negative conditioning crop coordinates",
),
"aesthetic_score": InputParam(
"aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
),
"negative_aesthetic_score": InputParam(
"negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam(
"output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
),
"ip_adapter_image": InputParam(
"ip_adapter_image",
type_hint=PipelineImageInput,
required=True,
description="Image(s) to be used as IP adapter",
),
"control_image": InputParam(
"control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
),
"control_guidance_start": InputParam(
"control_guidance_start",
type_hint=Union[float, List[float]],
default=0.0,
description="When ControlNet starts applying",
),
"control_guidance_end": InputParam(
"control_guidance_end",
type_hint=Union[float, List[float]],
default=1.0,
description="When ControlNet stops applying",
),
"controlnet_conditioning_scale": InputParam(
"controlnet_conditioning_scale",
type_hint=Union[float, List[float]],
default=1.0,
description="Scale factor for ControlNet outputs",
),
"guess_mode": InputParam(
"guess_mode",
type_hint=bool,
default=False,
description="Enables ControlNet encoder to recognize input without prompts",
),
"control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
),
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
required=True,
description="Text embeddings used to guide image generation",
),
"negative_prompt_embeds": InputParam(
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
),
"pooled_prompt_embeds": InputParam(
"pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
),
"negative_pooled_prompt_embeds": InputParam(
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
),
"latents": InputParam(
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
),
"latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
),
"image_latents": InputParam(
"image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam(
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
),
"add_time_ids": InputParam(
"add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
),
"negative_add_time_ids": InputParam(
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam(
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
),
"negative_ip_adapter_embeds": InputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative image embeddings for IP-Adapter",
),
"images": InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
required=True,
description="Generated images",
),
}
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
DEFAULT_PARAM_MAPS = {
"prompt": {
"label": "Prompt",
"type": "string",
"default": "a bear sitting in a chair drinking a milkshake",
"display": "textarea",
},
"negative_prompt": {
"label": "Negative Prompt",
"type": "string",
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
"display": "textarea",
},
"num_inference_steps": {
"label": "Steps",
"type": "int",
"default": 25,
"min": 1,
"max": 1000,
},
"seed": {
"label": "Seed",
"type": "int",
"default": 0,
"min": 0,
"display": "random",
},
"width": {
"label": "Width",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"height": {
"label": "Height",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"images": {
"label": "Images",
"type": "image",
"display": "output",
},
"image": {
"label": "Image",
"type": "image",
"display": "input",
},
}
DEFAULT_TYPE_MAPS = {
"int": {
"type": "int",
"default": 0,
"min": 0,
},
"float": {
"type": "float",
"default": 0.0,
"min": 0.0,
},
"str": {
"type": "string",
"default": "",
},
"bool": {
"type": "boolean",
"default": False,
},
"image": {
"type": "image",
},
}
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
"text_encoders": ["text_encoder", "tokenizer"],
"ip_adapter_embeds": ["ip_adapter_embeds"],
"prompt_embeddings": ["prompt_embeds"],
}
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
"""
Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" ->
"text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
"""
if name is None:
return None
for group_name, group_keys in group_params_keys.items():
for group_key in group_keys:
if group_key in name:
return group_name
return None
class ModularNode(ConfigMixin):
"""
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
around a ModularPipelineBlocks object.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
"""
config_name = "node_config.json"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
blocks = ModularPipelineBlocks.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
return cls(blocks, **kwargs)
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
self.blocks = blocks
if label is None:
label = self.blocks.__class__.__name__
# blocks param name -> mellon param name
self.name_mapping = {}
input_params = {}
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
# "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
for inp in inputs:
param = kwargs.pop(inp.name, None)
if param:
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
input_params[inp.name] = param
mellon_name = param.pop("name", inp.name)
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
elif get_group_name(inp.name):
param = get_group_name(inp.name)
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
type_str = str(inp.type_hint).lower()
else:
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param = type_param.copy()
param["label"] = inp.name
param["display"] = "input"
break
else:
param = inp.name
# add the param dict to the inp_params dict
input_params[inp.name] = param
component_params = {}
for comp in self.blocks.expected_components:
param = kwargs.pop(comp.name, None)
if param:
component_params[comp.name] = param
mellon_name = param.pop("name", comp.name)
if mellon_name != comp.name:
self.name_mapping[comp.name] = mellon_name
continue
to_exclude = False
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
if exclude_key in comp.name:
to_exclude = True
break
if to_exclude:
continue
if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
self.name_mapping[comp.name] = param
elif comp.name in DEFAULT_MODEL_KEYS:
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
else:
param = comp.name
# add the param dict to the model_params dict
component_params[comp.name] = param
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.sub_blocks.keys())[-1]
outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
else:
outputs = self.blocks.intermediate_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
if param:
output_params[out.name] = param
mellon_name = param.pop("name", out.name)
if mellon_name != out.name:
self.name_mapping[out.name] = mellon_name
continue
if out.name in DEFAULT_PARAM_MAPS:
param = DEFAULT_PARAM_MAPS[out.name].copy()
param["display"] = "output"
else:
group_name = get_group_name(out.name)
if group_name:
param = group_name
if out.name not in self.name_mapping:
self.name_mapping[out.name] = param
else:
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param
if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")
register_dict = {
"category": category,
"label": label,
"input_params": input_params,
"component_params": component_params,
"output_params": output_params,
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)
def setup(self, components_manager, collection=None):
self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
self._components_manager = components_manager
@property
def mellon_config(self):
return self._convert_to_mellon_config()
def _convert_to_mellon_config(self):
node = {}
node["label"] = self.config.label
node["category"] = self.config.category
node_param = {}
for inp_name, inp_param in self.config.input_params.items():
if inp_name in self.name_mapping:
mellon_name = self.name_mapping[inp_name]
else:
mellon_name = inp_name
if isinstance(inp_param, str):
param = {
"label": inp_param,
"type": inp_param,
"display": "input",
}
else:
param = inp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
else:
mellon_name = comp_name
if isinstance(comp_param, str):
param = {
"label": comp_param,
"type": comp_param,
"display": "input",
}
else:
param = comp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
else:
mellon_name = out_name
if isinstance(out_param, str):
param = {
"label": out_param,
"type": out_param,
"display": "output",
}
else:
param = out_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
node["params"] = node_param
return node
def save_mellon_config(self, file_path):
"""
Save the Mellon configuration to a JSON file.
Args:
file_path (str or Path): Path where the JSON file will be saved
Returns:
Path: Path to the saved config file
"""
file_path = Path(file_path)
# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)
# Create a combined dictionary with module definition and name mapping
config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
# Save the config to file
with open(file_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
logger.info(f"Mellon config and name mapping saved to {file_path}")
return file_path
@classmethod
def load_mellon_config(cls, file_path):
"""
Load a Mellon configuration from a JSON file.
Args:
file_path (str or Path): Path to the JSON file containing Mellon config
Returns:
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")
with open(file_path, "r", encoding="utf-8") as f:
config = json.load(f)
logger.info(f"Mellon config loaded from {file_path}")
return config
def process_inputs(self, **kwargs):
params_components = {}
for comp_name, comp_param in self.config.component_params.items():
logger.debug(f"component: {comp_name}")
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
if mellon_comp_name in kwargs:
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
comp = kwargs[mellon_comp_name].pop(comp_name)
else:
comp = kwargs.pop(mellon_comp_name)
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
if mellon_inp_name in kwargs:
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
inp = kwargs[mellon_inp_name].pop(inp_name)
else:
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp
return_output_names = list(self.config.output_params.keys())
return params_components, params_run, return_output_names
def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
self.pipeline.update_components(**params_components)
output = self.pipeline(**params_run, output=return_output_names)
return output
@@ -0,0 +1,77 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"CONTROLNET_BLOCKS",
"IMAGE2IMAGE_BLOCKS",
"INPAINT_BLOCKS",
"IP_ADAPTER_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLAutoControlnetStep",
"StableDiffusionXLAutoDecodeStep",
"StableDiffusionXLAutoIPAdapterStep",
"StableDiffusionXLAutoVaeEncoderStep",
]
_import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .encoders import (
StableDiffusionXLTextEncoderStep,
)
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
CONTROLNET_BLOCKS,
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
IP_ADAPTER_BLOCKS,
TEXT2IMAGE_BLOCKS,
StableDiffusionXLAutoBlocks,
StableDiffusionXLAutoControlnetStep,
StableDiffusionXLAutoDecodeStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
)
from .modular_pipeline import StableDiffusionXLModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,215 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@staticmethod
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components
def upcast_vae(components):
dtype = components.vae.dtype
components.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
components.vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
components.vae.post_quant_conv.to(dtype)
components.vae.decoder.conv_in.to(dtype)
components.vae.decoder.mid_block.to(dtype)
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if not block_state.output_type == "latent":
latents = block_state.latents
# make sure the VAE is in float32 mode, as it overflows in float16
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
if block_state.needs_upcasting:
self.upcast_vae(components)
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != components.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
components.vae = components.vae.to(latents.dtype)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
block_state.has_latents_mean = (
hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
)
block_state.has_latents_std = (
hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
)
if block_state.has_latents_mean and block_state.has_latents_std:
block_state.latents_mean = (
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
block_state.latents_std = (
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = (
latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
)
else:
latents = latents / components.vae.config.scaling_factor
block_state.images = components.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if block_state.needs_upcasting:
components.vae.to(dtype=torch.float16)
else:
block_state.images = block_state.latents
# apply watermark if available
if hasattr(components, "watermark") and components.watermark is not None:
block_state.images = components.watermark.apply_watermark(block_state.images)
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"A post-processing step that overlays the mask on the image (inpainting task only).\n"
+ "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("image"),
InputParam("mask_image"),
InputParam("padding_mask_crop"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="The generated images from the decode step",
),
InputParam(
"crops_coords",
type_hint=Tuple[int, int],
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.",
),
]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if block_state.padding_mask_crop is not None and block_state.crops_coords is not None:
block_state.images = [
components.image_processor.apply_overlay(
block_state.mask_image, block_state.image, i, block_state.crops_coords
)
for i in block_state.images
]
self.set_block_state(state, block_state)
return components, state
@@ -0,0 +1,791 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import ControlNetModel, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi experimenting composible denoise loop
# loop step (1): prepare latent input for denoiser
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that prepare the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
]
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
return components, block_state
# loop step (1): prepare latent input for denoiser (with inpainting)
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object"
)
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"mask",
type_hint=Optional[torch.Tensor],
description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
),
InputParam(
"masked_image_latents",
type_hint=Optional[torch.Tensor],
description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
),
]
@staticmethod
def check_inputs(components, block_state):
num_channels_unet = components.num_channels_unet
if num_channels_unet == 9:
# default case for runwayml/stable-diffusion-inpainting
if block_state.mask is None or block_state.masked_image_latents is None:
raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet")
num_channels_latents = block_state.latents.shape[1]
num_channels_mask = block_state.mask.shape[1]
num_channels_masked_image = block_state.masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
raise ValueError(
f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects"
f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `components.unet` or your `mask_image` or `image` input."
)
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
self.check_inputs(components, block_state)
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
if components.num_channels_unet == 9:
block_state.scaled_latents = torch.cat(
[block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1
)
return components, block_state
# loop step (2): denoise the latents with guidance
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"timestep_cond",
type_hint=Optional[torch.Tensor],
description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.",
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds, "
"add_time_ids/negative_add_time_ids, "
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
"please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@torch.no_grad()
def __call__(
self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
"time_ids": ("add_time_ids", "negative_add_time_ids"),
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.unet)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
block_state.scaled_latents,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=block_state.timestep_cond,
cross_attention_kwargs=block_state.cross_attention_kwargs,
added_cond_kwargs=cond_kwargs,
return_dict=False,
)[0]
components.guider.cleanup_models(components.unet)
# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
return components, block_state
# loop step (2): denoise the latents with guidance (with controlnet)
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetModel),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that denoise the latents with guidance (with controlnet). "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"controlnet_cond",
required=True,
type_hint=torch.Tensor,
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"conditioning_scale",
type_hint=float,
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"guess_mode",
required=True,
type_hint=bool,
description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"controlnet_keep",
required=True,
type_hint=List[float],
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"timestep_cond",
type_hint=Optional[torch.Tensor],
description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds, "
"add_time_ids/negative_add_time_ids, "
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
"please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
InputParam(
kwargs_type="controlnet_kwargs",
description=(
"additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )"
"please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@staticmethod
def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
accepted_kwargs = set(inspect.signature(func).parameters.keys())
extra_kwargs = {}
for key, value in kwargs.items():
if key in accepted_kwargs and key not in exclude_kwargs:
extra_kwargs[key] = value
return extra_kwargs
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
extra_controlnet_kwargs = self.prepare_extra_kwargs(
components.controlnet.forward, **block_state.controlnet_kwargs
)
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
"time_ids": ("add_time_ids", "negative_add_time_ids"),
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
}
# cond_scale for the timestep (controlnet input)
if isinstance(block_state.controlnet_keep[i], list):
block_state.cond_scale = [
c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])
]
else:
controlnet_cond_scale = block_state.conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
# default controlnet output/unet input for guess mode + conditional path
block_state.down_block_res_samples_zeros = None
block_state.mid_block_res_sample_zeros = None
# guided denoiser step
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.unet)
# Prepare additional conditionings
added_cond_kwargs = {
"text_embeds": guider_state_batch.text_embeds,
"time_ids": guider_state_batch.time_ids,
}
if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None:
added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds
# Prepare controlnet additional conditionings
controlnet_added_cond_kwargs = {
"text_embeds": guider_state_batch.text_embeds,
"time_ids": guider_state_batch.time_ids,
}
# run controlnet for the guidance batch
if block_state.guess_mode and not components.guider.is_conditional:
# guider always run uncond batch first, so these tensors should be set already
down_block_res_samples = block_state.down_block_res_samples_zeros
mid_block_res_sample = block_state.mid_block_res_sample_zeros
else:
down_block_res_samples, mid_block_res_sample = components.controlnet(
block_state.scaled_latents,
t,
encoder_hidden_states=guider_state_batch.prompt_embeds,
controlnet_cond=block_state.controlnet_cond,
conditioning_scale=block_state.cond_scale,
guess_mode=block_state.guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
return_dict=False,
**extra_controlnet_kwargs,
)
# assign it to block_state so it will be available for the uncond guidance batch
if block_state.down_block_res_samples_zeros is None:
block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples]
if block_state.mid_block_res_sample_zeros is None:
block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample)
# Predict the noise
# store the noise_pred in guider_state_batch so we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
block_state.scaled_latents,
t,
encoder_hidden_states=guider_state_batch.prompt_embeds,
timestep_cond=block_state.timestep_cond,
cross_attention_kwargs=block_state.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
components.guider.cleanup_models(components.unet)
# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
return components, block_state
# loop step (3): scheduler step to update latents
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that update the latents. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("eta", default=0.0),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
# YiYi TODO: move this out of here
@staticmethod
def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
accepted_kwargs = set(inspect.signature(func).parameters.keys())
extra_kwargs = {}
for key, value in kwargs.items():
if key in accepted_kwargs and key not in exclude_kwargs:
extra_kwargs[key] = value
return extra_kwargs
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
block_state.extra_step_kwargs = self.prepare_extra_kwargs(
components.scheduler.step, generator=block_state.generator, eta=block_state.eta
)
# Perform scheduler step using the predicted output
block_state.latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred,
t,
block_state.latents,
**block_state.extra_step_kwargs,
**block_state.scheduler_step_kwargs,
return_dict=False,
)[0]
if block_state.latents.dtype != block_state.latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
block_state.latents = block_state.latents.to(block_state.latents_dtype)
return components, block_state
# loop step (3): scheduler step to update latents (with inpainting)
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that update the latents (for inpainting workflow only). "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("eta", default=0.0),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"mask",
type_hint=Optional[torch.Tensor],
description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
),
InputParam(
"noise",
type_hint=Optional[torch.Tensor],
description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step.",
),
InputParam(
"image_latents",
type_hint=Optional[torch.Tensor],
description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@staticmethod
def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
accepted_kwargs = set(inspect.signature(func).parameters.keys())
extra_kwargs = {}
for key, value in kwargs.items():
if key in accepted_kwargs and key not in exclude_kwargs:
extra_kwargs[key] = value
return extra_kwargs
def check_inputs(self, components, block_state):
if components.num_channels_unet == 4:
if block_state.image_latents is None:
raise ValueError(f"image_latents is required for this step {self.__class__.__name__}")
if block_state.mask is None:
raise ValueError(f"mask is required for this step {self.__class__.__name__}")
if block_state.noise is None:
raise ValueError(f"noise is required for this step {self.__class__.__name__}")
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
self.check_inputs(components, block_state)
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
block_state.extra_step_kwargs = self.prepare_extra_kwargs(
components.scheduler.step, generator=block_state.generator, eta=block_state.eta
)
# Perform scheduler step using the predicted output
block_state.latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred,
t,
block_state.latents,
**block_state.extra_step_kwargs,
**block_state.scheduler_step_kwargs,
return_dict=False,
)[0]
if block_state.latents.dtype != block_state.latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
block_state.latents = block_state.latents.to(block_state.latents_dtype)
# adjust latent for inpainting
if components.num_channels_unet == 4:
block_state.init_latents_proper = block_state.image_latents
if i < len(block_state.timesteps) - 1:
block_state.noise_timestep = block_state.timesteps[i + 1]
block_state.init_latents_proper = components.scheduler.add_noise(
block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep])
)
block_state.latents = (
1 - block_state.mask
) * block_state.init_latents_proper + block_state.mask * block_state.latents
return components, block_state
# the loop wrapper that iterates over the timesteps
class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoise the latents over `timesteps`. "
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
if block_state.disable_guidance:
components.guider.disable()
else:
components.guider.enable()
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
self.set_block_state(state, block_state)
return components, state
# composing the denoising loops
class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [
StableDiffusionXLLoopBeforeDenoiser,
StableDiffusionXLLoopDenoiser,
StableDiffusionXLLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `StableDiffusionXLLoopBeforeDenoiser`\n"
" - `StableDiffusionXLLoopDenoiser`\n"
" - `StableDiffusionXLLoopAfterDenoiser`\n"
"This block supports both text2img and img2img tasks."
)
# control_cond
class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [
StableDiffusionXLLoopBeforeDenoiser,
StableDiffusionXLControlNetLoopDenoiser,
StableDiffusionXLLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents with controlnet. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `StableDiffusionXLLoopBeforeDenoiser`\n"
" - `StableDiffusionXLControlNetLoopDenoiser`\n"
" - `StableDiffusionXLLoopAfterDenoiser`\n"
"This block supports using controlnet for both text2img and img2img tasks."
)
# mask
class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [
StableDiffusionXLInpaintLoopBeforeDenoiser,
StableDiffusionXLLoopDenoiser,
StableDiffusionXLInpaintLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents(for inpainting task only). \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
" - `StableDiffusionXLLoopDenoiser`\n"
" - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
"This block onlysupports inpainting tasks."
)
# control_cond + mask
class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [
StableDiffusionXLInpaintLoopBeforeDenoiser,
StableDiffusionXLControlNetLoopDenoiser,
StableDiffusionXLInpaintLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
" - `StableDiffusionXLControlNetLoopDenoiser`\n"
" - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
"This block only supports using controlnet for inpainting tasks."
)
@@ -0,0 +1,902 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
import torch
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"IP Adapter step that prepares ip adapter image embeddings.\n"
"Note that this step only prepares the embeddings - in order for it to work correctly, "
"you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n"
"See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
" for more details"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
ComponentSpec(
"feature_extractor",
CLIPImageProcessor,
config=FrozenDict({"size": 224, "crop_size": 224}),
default_creation_method="from_config",
),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"ip_adapter_image",
PipelineImageInput,
required=True,
description="The image(s) to be used as ip adapter",
)
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
OutputParam(
"negative_ip_adapter_embeds",
type_hint=torch.Tensor,
description="Negative IP adapter image embeddings",
),
]
@staticmethod
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components
def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(components.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = components.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = components.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = components.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self,
components,
ip_adapter_image,
ip_adapter_image_embeds,
device,
num_images_per_prompt,
prepare_unconditional_embeds,
):
image_embeds = []
if prepare_unconditional_embeds:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
components, single_ip_adapter_image, device, 1, output_hidden_state
)
image_embeds.append(single_image_embeds[None, :])
if prepare_unconditional_embeds:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if prepare_unconditional_embeds:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if prepare_unconditional_embeds:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
components,
ip_adapter_image=block_state.ip_adapter_image,
ip_adapter_image_embeds=None,
device=block_state.device,
num_images_per_prompt=1,
prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
)
if block_state.prepare_unconditional_embeds:
block_state.negative_ip_adapter_embeds = []
for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
negative_image_embeds, image_embeds = image_embeds.chunk(2)
block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
block_state.ip_adapter_embeds[i] = image_embeds
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return "Text Encoder step that generate text_embeddings to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", CLIPTextModel),
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("tokenizer_2", CLIPTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [ConfigSpec("force_zeros_for_empty_prompt", True)]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("negative_prompt"),
InputParam("negative_prompt_2"),
InputParam("cross_attention_kwargs"),
InputParam("clip_skip"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="negative text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="pooled text embeddings used to guide the image generation",
),
OutputParam(
"negative_pooled_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="negative pooled text embeddings used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
elif block_state.prompt_2 is not None and (
not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)
):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
@staticmethod
def encode_prompt(
components,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prepare_unconditional_embeds (`bool`):
whether to use prepare unconditional embeddings or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
device = device or components._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
components._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if components.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
else:
scale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
else:
scale_lora_layers(components.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
tokenizers = (
[components.tokenizer, components.tokenizer_2]
if components.tokenizer is not None
else [components.tokenizer_2]
)
text_encoders = (
[components.text_encoder, components.text_encoder_2]
if components.text_encoder is not None
else [components.text_encoder_2]
)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, tokenizer)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif prepare_unconditional_embeds and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
if components.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if prepare_unconditional_embeds:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
if components.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(
dtype=components.text_encoder_2.dtype, device=device
)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if prepare_unconditional_embeds:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if components.text_encoder is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input prompt
block_state.text_encoder_lora_scale = (
block_state.cross_attention_kwargs.get("scale", None)
if block_state.cross_attention_kwargs is not None
else None
)
(
block_state.prompt_embeds,
block_state.negative_prompt_embeds,
block_state.pooled_prompt_embeds,
block_state.negative_pooled_prompt_embeds,
) = self.encode_prompt(
components,
block_state.prompt,
block_state.prompt_2,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
block_state.negative_prompt_2,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
lora_scale=block_state.text_encoder_lora_scale,
clip_skip=block_state.clip_skip,
)
# Add outputs
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return "Vae Encoder step that encode the input image into a latent representation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam(
"preprocess_kwargs",
type_hint=Optional[dict],
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=torch.Tensor,
description="The latents representing the reference image for image-to-image/inpainting generation",
)
]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
)
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = block_state.image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.image_latents = self._encode_vae_image(
components, image=block_state.image, generator=block_state.generator
)
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
ComponentSpec(
"mask_processor",
VaeImageProcessor,
config=FrozenDict(
{"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}
),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Vae encoder step that prepares the image and mask for the inpainting process"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height"),
InputParam("width"),
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"
),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
OutputParam(
"masked_image_latents",
type_hint=torch.Tensor,
description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)",
),
OutputParam(
"crops_coords",
type_hint=Optional[Tuple[int, int]],
description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
),
]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# do not accept do_classifier_free_guidance
def prepare_mask_latents(
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
if block_state.height is None:
block_state.height = components.default_height
if block_state.width is None:
block_state.width = components.default_width
if block_state.padding_mask_crop is not None:
block_state.crops_coords = components.mask_processor.get_crop_region(
block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
)
block_state.resize_mode = "fill"
else:
block_state.crops_coords = None
block_state.resize_mode = "default"
block_state.image = components.image_processor.preprocess(
block_state.image,
height=block_state.height,
width=block_state.width,
crops_coords=block_state.crops_coords,
resize_mode=block_state.resize_mode,
)
block_state.image = block_state.image.to(dtype=torch.float32)
block_state.mask = components.mask_processor.preprocess(
block_state.mask_image,
height=block_state.height,
width=block_state.width,
resize_mode=block_state.resize_mode,
crops_coords=block_state.crops_coords,
)
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
block_state.batch_size = block_state.image.shape[0]
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.image_latents = self._encode_vae_image(
components, image=block_state.image, generator=block_state.generator
)
# 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components,
block_state.mask,
block_state.masked_image,
block_state.batch_size,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
)
self.set_block_state(state, block_state)
return components, state
@@ -0,0 +1,380 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
StableDiffusionXLControlNetInputStep,
StableDiffusionXLControlNetUnionInputStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLInputStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep,
)
from .decoders import (
StableDiffusionXLDecodeStep,
StableDiffusionXLInpaintOverlayMaskStep,
)
from .denoise import (
StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDenoiseStep,
StableDiffusionXLInpaintControlNetDenoiseStep,
StableDiffusionXLInpaintDenoiseStep,
)
from .encoders import (
StableDiffusionXLInpaintVaeEncoderStep,
StableDiffusionXLIPAdapterStep,
StableDiffusionXLTextEncoderStep,
StableDiffusionXLVaeEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# auto blocks & sequential blocks & mappings
# vae encoder (run before before_denoise)
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for both inpainting and img2img tasks.\n"
+ " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided."
+ " - if neither `mask_image` nor `image` is provided, step will be skipped."
)
# optional ip-adapter (run before input step)
class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLIPAdapterStep]
block_names = ["ip_adapter"]
block_trigger_inputs = ["ip_adapter_image"]
@property
def description(self):
return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
# before_denoise: text2img
class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLInputStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
)
# before_denoise: img2img
class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLInputStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
)
# before_denoise: inpainting
class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLInputStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for inpainting task.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
)
# before_denoise: all task (text2img, img2img, inpainting)
class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [
StableDiffusionXLInpaintBeforeDenoiseStep,
StableDiffusionXLImg2ImgBeforeDenoiseStep,
StableDiffusionXLBeforeDenoiseStep,
]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image_latents", None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n"
+ " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n"
+ " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
+ " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n"
)
# optional controlnet input step (after before_denoise, before denoise)
# works for both controlnet and controlnet_union
class StableDiffusionXLAutoControlNetInputStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep]
block_names = ["controlnet_union", "controlnet"]
block_trigger_inputs = ["control_mode", "control_image"]
@property
def description(self):
return (
"Controlnet Input step that prepare the controlnet input.\n"
+ "This is an auto pipeline block that works for both controlnet and controlnet_union.\n"
+ " (it should be called right before the denoise step)"
+ " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n"
+ " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided."
+ " - if neither `control_mode` nor `control_image` is provided, step will be skipped."
)
# denoise: controlnet (text2img, img2img, inpainting)
class StableDiffusionXLAutoControlNetDenoiseStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep]
block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"]
block_trigger_inputs = ["mask", "controlnet_cond"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents with controlnet. "
"This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks."
"This block should not be used without a controlnet_cond input"
" - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided."
" - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided."
" - If neither mask nor controlnet_cond are provided, step will be skipped."
)
# denoise: all task with or without controlnet (text2img, img2img, inpainting)
class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
StableDiffusionXLAutoControlNetDenoiseStep,
StableDiffusionXLInpaintDenoiseStep,
StableDiffusionXLDenoiseStep,
]
block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
block_trigger_inputs = ["controlnet_cond", "mask", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet."
" - `StableDiffusionXLAutoControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (support controlnet withtext2img, img2img and inpainting tasks)."
" - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided (support inpainting tasks)."
" - `StableDiffusionXLDenoiseStep` (denoise) is used when neither mask nor controlnet_cond are provided (support text2img and img2img tasks)."
)
# decode: inpaint
class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep]
block_names = ["decode", "mask_overlay"]
@property
def description(self):
return (
"Inpaint decode step that decode the denoised latents into images outputs.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n"
+ " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image"
)
# decode: all task (text2img, img2img, inpainting)
class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
block_names = ["inpaint", "non-inpaint"]
block_trigger_inputs = ["padding_mask_crop", None]
@property
def description(self):
return (
"Decode step that decode the denoised latents into images outputs.\n"
+ "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n"
+ " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n"
+ " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
)
# ip-adapter, controlnet, text2img, img2img, inpainting
class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLAutoBeforeDenoiseStep,
StableDiffusionXLAutoControlNetInputStep,
StableDiffusionXLAutoDenoiseStep,
StableDiffusionXLAutoDecodeStep,
]
block_names = [
"text_encoder",
"ip_adapter",
"image_encoder",
"before_denoise",
"controlnet_input",
"denoise",
"decoder",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n"
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`\n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
+ "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
+ "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)
# controlnet (input + denoise step)
class StableDiffusionXLAutoControlnetStep(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLAutoControlNetInputStep,
StableDiffusionXLAutoControlNetDenoiseStep,
]
block_names = ["controlnet_input", "controlnet_denoise"]
@property
def description(self):
return (
"Controlnet auto step that prepare the controlnet input and denoise the latents. "
+ "It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks."
+ " (it should be replace at 'denoise' step)"
)
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLSetTimestepsStep),
("prepare_latents", StableDiffusionXLPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("decode", StableDiffusionXLDecodeStep),
]
)
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("image_encoder", StableDiffusionXLVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("decode", StableDiffusionXLDecodeStep),
]
)
INPAINT_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLInpaintDenoiseStep),
("decode", StableDiffusionXLInpaintDecodeStep),
]
)
CONTROLNET_BLOCKS = InsertableDict(
[
("denoise", StableDiffusionXLAutoControlnetStep),
]
)
IP_ADAPTER_BLOCKS = InsertableDict(
[
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
("controlnet_input", StableDiffusionXLAutoControlNetInputStep),
("denoise", StableDiffusionXLAutoDenoiseStep),
("decode", StableDiffusionXLAutoDecodeStep),
]
)
ALL_BLOCKS = {
"text2img": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"inpaint": INPAINT_BLOCKS,
"controlnet": CONTROLNET_BLOCKS,
"ip_adapter": IP_ADAPTER_BLOCKS,
"auto": AUTO_BLOCKS,
}
@@ -0,0 +1,376 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from ...image_processor import PipelineImageInput
from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...utils import logging
from ..modular_pipeline import ModularPipeline
from ..modular_pipeline_utils import InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder?
# YiYi Notes: model specific components:
## (1) it should inherit from ModularPipeline
## (2) acts like a container that holds components and configs
## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin)
## (5) how to use together with Components_manager?
class StableDiffusionXLModularPipeline(
ModularPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
ModularIPAdapterMixin,
):
"""
A ModularPipeline for Stable Diffusion XL.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
"""
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
default_sample_size = 128
if hasattr(self, "unet") and self.unet is not None:
default_sample_size = self.unet.config.sample_size
return default_sample_size
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_unet(self):
num_channels_unet = 4
if hasattr(self, "unet") and self.unet is not None:
num_channels_unet = self.unet.config.in_channels
return num_channels_unet
@property
def num_channels_latents(self):
num_channels_latents = 4
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.latent_channels
return num_channels_latents
# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks
# auto_docstring
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam(
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
),
"prompt_2": InputParam(
"prompt_2",
type_hint=Union[str, List[str]],
description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
),
"negative_prompt": InputParam(
"negative_prompt",
type_hint=Union[str, List[str]],
description="The prompt or prompts not to guide the image generation",
),
"negative_prompt_2": InputParam(
"negative_prompt_2",
type_hint=Union[str, List[str]],
description="The negative prompt or prompts for text_encoder_2",
),
"cross_attention_kwargs": InputParam(
"cross_attention_kwargs",
type_hint=Optional[dict],
description="Kwargs dictionary passed to the AttentionProcessor",
),
"clip_skip": InputParam(
"clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
),
"image": InputParam(
"image",
type_hint=PipelineImageInput,
required=True,
description="The image(s) to modify for img2img or inpainting",
),
"mask_image": InputParam(
"mask_image",
type_hint=PipelineImageInput,
required=True,
description="Mask image for inpainting, white pixels will be repainted",
),
"generator": InputParam(
"generator",
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
description="Generator(s) for deterministic generation",
),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam(
"num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
),
"timesteps": InputParam(
"timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
),
"sigmas": InputParam(
"sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
),
"denoising_end": InputParam(
"denoising_end",
type_hint=Optional[float],
description="Fraction of denoising process to complete before termination",
),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam(
"strength", type_hint=float, default=0.3, description="How much to transform the reference image"
),
"denoising_start": InputParam(
"denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
),
"latents": InputParam(
"latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
),
"padding_mask_crop": InputParam(
"padding_mask_crop",
type_hint=Optional[Tuple[int, int]],
description="Size of margin in crop for image and mask",
),
"original_size": InputParam(
"original_size",
type_hint=Optional[Tuple[int, int]],
description="Original size of the image for SDXL's micro-conditioning",
),
"target_size": InputParam(
"target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
),
"negative_original_size": InputParam(
"negative_original_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on image resolution",
),
"negative_target_size": InputParam(
"negative_target_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on target resolution",
),
"crops_coords_top_left": InputParam(
"crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Top-left coordinates for SDXL's micro-conditioning",
),
"negative_crops_coords_top_left": InputParam(
"negative_crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Negative conditioning crop coordinates",
),
"aesthetic_score": InputParam(
"aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
),
"negative_aesthetic_score": InputParam(
"negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam(
"output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
),
"ip_adapter_image": InputParam(
"ip_adapter_image",
type_hint=PipelineImageInput,
required=True,
description="Image(s) to be used as IP adapter",
),
"control_image": InputParam(
"control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
),
"control_guidance_start": InputParam(
"control_guidance_start",
type_hint=Union[float, List[float]],
default=0.0,
description="When ControlNet starts applying",
),
"control_guidance_end": InputParam(
"control_guidance_end",
type_hint=Union[float, List[float]],
default=1.0,
description="When ControlNet stops applying",
),
"controlnet_conditioning_scale": InputParam(
"controlnet_conditioning_scale",
type_hint=Union[float, List[float]],
default=1.0,
description="Scale factor for ControlNet outputs",
),
"guess_mode": InputParam(
"guess_mode",
type_hint=bool,
default=False,
description="Enables ControlNet encoder to recognize input without prompts",
),
"control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
),
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
required=True,
description="Text embeddings used to guide image generation",
),
"negative_prompt_embeds": InputParam(
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
),
"pooled_prompt_embeds": InputParam(
"pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
),
"negative_pooled_prompt_embeds": InputParam(
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
),
"latents": InputParam(
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
),
"latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
),
"image_latents": InputParam(
"image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam(
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
),
"add_time_ids": InputParam(
"add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
),
"negative_add_time_ids": InputParam(
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam(
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
),
"negative_ip_adapter_embeds": InputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative image embeddings for IP-Adapter",
),
"images": InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
required=True,
description="Generated images",
),
}
SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
"prompt_embeds": OutputParam(
"prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"
),
"negative_prompt_embeds": OutputParam(
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
),
"pooled_prompt_embeds": OutputParam(
"pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"
),
"negative_pooled_prompt_embeds": OutputParam(
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
),
"batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"),
"dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"image_latents": OutputParam(
"image_latents", type_hint=torch.Tensor, description="Latents representing reference image"
),
"mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"),
"masked_image_latents": OutputParam(
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
),
"crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"),
"num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"),
"latent_timestep": OutputParam(
"latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"
),
"add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"),
"negative_add_time_ids": OutputParam(
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
),
"timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
"noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"ip_adapter_embeds": OutputParam(
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
),
"negative_ip_adapter_embeds": OutputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative image embeddings for IP-Adapter",
),
"images": OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="Generated images",
),
}
SDXL_OUTPUTS_SCHEMA = {
"images": OutputParam(
"images",
type_hint=Union[
Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput
],
description="The final generated images",
)
}
@@ -0,0 +1,66 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = ["WanTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"TEXT2VIDEO_BLOCKS",
"WanAutoBeforeDenoiseStep",
"WanAutoBlocks",
"WanAutoBlocks",
"WanAutoDecodeStep",
"WanAutoDenoiseStep",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .encoders import WanTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
TEXT2VIDEO_BLOCKS,
WanAutoBeforeDenoiseStep,
WanAutoBlocks,
WanAutoDecodeStep,
WanAutoDenoiseStep,
)
from .modular_pipeline import WanModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,365 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Union
import torch
from ...schedulers import UniPCMultistepScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that
# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by
# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the
# configuration of guider is.
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class WanInputStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return (
"Input processing step that:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n"
"All input tensors are expected to have either batch_size=1 or match the batch_size\n"
"of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
"have a final batch_size of batch_size * num_videos_per_prompt."
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_videos_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
description="negative text embeddings used to guide the image generation",
),
]
def check_inputs(self, components, block_state):
if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
f" {block_state.negative_prompt_embeds.shape}."
)
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
)
if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
1, block_state.num_videos_per_prompt, 1
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
class WanSetTimestepsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", UniPCMultistepScheduler),
]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for inference"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
"num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
components.scheduler,
block_state.num_inference_steps,
block_state.device,
block_state.timesteps,
block_state.sigmas,
)
self.set_block_state(state, block_state)
return components, state
class WanPrepareLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def description(self) -> str:
return "Prepare latents step that prepares the latents for the text-to-video generation process"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("num_frames", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_videos_per_prompt", type_hint=int, default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. Can be generated in input step.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
)
]
@staticmethod
def check_inputs(components, block_state):
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
)
if block_state.num_frames is not None and (
block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
):
raise ValueError(
f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
)
@staticmethod
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp
def prepare_latents(
comp,
batch_size: int,
num_channels_latents: int = 16,
height: int = 480,
width: int = 832,
num_frames: int = 81,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_latent_frames = (num_frames - 1) // comp.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_channels_latents,
num_latent_frames,
int(height) // comp.vae_scale_factor_spatial,
int(width) // comp.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.num_frames = block_state.num_frames or components.default_num_frames
block_state.device = components._execution_device
block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
block_state.latents = self.prepare_latents(
components,
block_state.batch_size * block_state.num_videos_per_prompt,
block_state.num_channels_latents,
block_state.height,
block_state.width,
block_state.num_frames,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.latents,
)
self.set_block_state(state, block_state)
return components, state
@@ -0,0 +1,105 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLWan
from ...utils import logging
from ...video_processor import VideoProcessor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanDecodeStep(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLWan),
ComponentSpec(
"video_processor",
VideoProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"videos",
type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]],
description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae_dtype = components.vae.dtype
if not block_state.output_type == "latent":
latents = block_state.latents
latents_mean = (
torch.tensor(components.vae.config.latents_mean)
.view(1, components.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
1, components.vae.config.z_dim, 1, 1, 1
).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
latents = latents.to(vae_dtype)
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
else:
block_state.videos = block_state.latents
block_state.videos = components.video_processor.postprocess_video(
block_state.videos, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state
@@ -0,0 +1,261 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import WanTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanLoopDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", WanTransformer3DModel),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds. "
"Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@torch.no_grad()
def __call__(
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
}
transformer_dtype = components.transformer.dtype
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latents.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds,
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
)[0]
components.guider.cleanup_models(components.transformer)
# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
return components, block_state
class WanLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", UniPCMultistepScheduler),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that update the latents. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return []
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
# Perform scheduler step using the predicted output
latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred.float(),
t,
block_state.latents.float(),
**block_state.scheduler_step_kwargs,
return_dict=False,
)[0]
if block_state.latents.dtype != latents_dtype:
block_state.latents = block_state.latents.to(latents_dtype)
return components, block_state
class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoise the latents over `timesteps`. "
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
ComponentSpec("scheduler", UniPCMultistepScheduler),
ComponentSpec("transformer", WanTransformer3DModel),
]
@property
def loop_intermediate_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
self.set_block_state(state, block_state)
return components, state
class WanDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanLoopDenoiser,
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports both text2vid tasks."
)
@@ -0,0 +1,242 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
from typing import List, Optional, Union
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...utils import is_ftfy_available, logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
if is_ftfy_available():
import ftfy
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
text = whitespace_clean(basic_clean(text))
return text
class WanTextEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "Text Encoder step that generate text_embeddings to guide the video generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", UMT5EncoderModel),
ComponentSpec("tokenizer", AutoTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("negative_prompt"),
InputParam("attention_kwargs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="negative text embeddings used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
@staticmethod
def _get_t5_prompt_embeds(
components,
prompt: Union[str, List[str]],
max_sequence_length: int,
device: torch.device,
):
dtype = components.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
text_inputs = components.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
return prompt_embeds
@staticmethod
def encode_prompt(
components,
prompt: str,
device: Optional[torch.device] = None,
num_videos_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_videos_per_prompt (`int`):
number of videos that should be generated per prompt
prepare_unconditional_embeds (`bool`):
whether to use prepare unconditional embeddings or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
max_sequence_length (`int`, defaults to `512`):
The maximum number of text tokens to be used for the generation process.
"""
device = device or components._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device)
if prepare_unconditional_embeds and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(
components, negative_prompt, max_sequence_length, device
)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
if prepare_unconditional_embeds:
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input prompt
(
block_state.prompt_embeds,
block_state.negative_prompt_embeds,
) = self.encode_prompt(
components,
block_state.prompt,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
prompt_embeds=None,
negative_prompt_embeds=None,
)
# Add outputs
self.set_block_state(state, block_state)
return components, state
@@ -0,0 +1,144 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
WanInputStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
)
from .decoders import WanDecodeStep
from .denoise import WanDenoiseStep
from .encoders import WanTextEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# before_denoise: text2vid
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
]
block_names = ["input", "set_timesteps", "prepare_latents"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
)
# before_denoise: all task (text2vid,)
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanBeforeDenoiseStep,
]
block_names = ["text2vid"]
block_trigger_inputs = [None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2vid.\n"
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
)
# denoise: text2vid
class WanAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanDenoiseStep,
]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2vid tasks.."
" - `WanDenoiseStep` (denoise) for text2vid tasks."
)
# decode: all task (text2img, img2img, inpainting)
class WanAutoDecodeStep(AutoPipelineBlocks):
block_classes = [WanDecodeStep]
block_names = ["non-inpaint"]
block_trigger_inputs = [None]
@property
def description(self):
return "Decode step that decode the denoised latents into videos outputs.\n - `WanDecodeStep`"
# text2vid
class WanAutoBlocks(SequentialPipelineBlocks):
block_classes = [
WanTextEncoderStep,
WanAutoBeforeDenoiseStep,
WanAutoDenoiseStep,
WanAutoDecodeStep,
]
block_names = [
"text_encoder",
"before_denoise",
"denoise",
"decoder",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-video using Wan.\n"
+ "- for text-to-video generation, all you need to provide is `prompt`"
)
TEXT2VIDEO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("input", WanInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", WanDenoiseStep),
("decode", WanDecodeStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("before_denoise", WanAutoBeforeDenoiseStep),
("denoise", WanAutoDenoiseStep),
("decode", WanAutoDecodeStep),
]
)
ALL_BLOCKS = {
"text2video": TEXT2VIDEO_BLOCKS,
"auto": AUTO_BLOCKS,
}
@@ -0,0 +1,90 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...loaders import WanLoraLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanModularPipeline(
ModularPipeline,
StableDiffusionMixin,
WanLoraLoaderMixin,
):
"""
A ModularPipeline for Wan.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
"""
@property
def default_height(self):
return self.default_sample_height * self.vae_scale_factor_spatial
@property
def default_width(self):
return self.default_sample_width * self.vae_scale_factor_spatial
@property
def default_num_frames(self):
return (self.default_sample_num_frames - 1) * self.vae_scale_factor_temporal + 1
@property
def default_sample_height(self):
return 60
@property
def default_sample_width(self):
return 104
@property
def default_sample_num_frames(self):
return 21
@property
def vae_scale_factor_spatial(self):
vae_scale_factor = 8
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
return vae_scale_factor
@property
def vae_scale_factor_temporal(self):
vae_scale_factor = 4
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** sum(self.vae.temperal_downsample)
return vae_scale_factor
@property
def num_channels_transformer(self):
num_channels_transformer = 16
if hasattr(self, "transformer") and self.transformer is not None:
num_channels_transformer = self.transformer.config.in_channels
return num_channels_transformer
@property
def num_channels_latents(self):
num_channels_latents = 16
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.z_dim
return num_channels_latents
+15
View File
@@ -380,6 +380,13 @@ else:
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
"SkyReelsV2DiffusionForcingVideoToVideoPipeline",
"SkyReelsV2ImageToVideoPipeline",
"SkyReelsV2Pipeline",
]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -851,6 +858,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SpectrogramDiffusionPipeline,
)
from .skyreels_v2 import (
SkyReelsV2DiffusionForcingImageToVideoPipeline,
SkyReelsV2DiffusionForcingPipeline,
SkyReelsV2DiffusionForcingVideoToVideoPipeline,
SkyReelsV2ImageToVideoPipeline,
SkyReelsV2Pipeline,
)
else:
import sys
+8 -7
View File
@@ -248,14 +248,15 @@ def _get_connected_pipeline(pipeline_cls):
return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
def get_model(pipeline_class_name):
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
for model_name, pipeline in task_mapping.items():
if pipeline.__name__ == pipeline_class_name:
return model_name
def _get_model(pipeline_class_name):
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
for model_name, pipeline in task_mapping.items():
if pipeline.__name__ == pipeline_class_name:
return model_name
model_name = get_model(pipeline_class_name)
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
model_name = _get_model(pipeline_class_name)
if model_name is not None:
task_class = mapping.get(model_name, None)
@@ -663,11 +663,11 @@ class ChromaPipeline(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
a model to generate images more aligned with `prompt` at the expense of lower image quality.
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):

Some files were not shown because too many files have changed in this diff Show More