Compare commits

..

2 Commits

Author SHA1 Message Date
Sayak Paul e72648a311 Merge branch 'main' into gpu-pr-test 2025-07-09 09:04:31 +05:30
DN6 3e3c0fcc1c update 2025-07-08 22:38:50 +05:30
182 changed files with 2370 additions and 31128 deletions
-1
View File
@@ -13,7 +13,6 @@ on:
- "src/diffusers/loaders/peft.py"
- "tests/pipelines/test_pipelines_common.py"
- "tests/models/test_modeling_common.py"
- "examples/**/*.py"
workflow_dispatch:
concurrency:
+1 -5
View File
@@ -47,10 +47,6 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
tensorboard \
transformers \
matplotlib \
setuptools==69.5.1 \
bitsandbytes \
torchao \
gguf \
optimum-quanto
setuptools==69.5.1
CMD ["/bin/bash"]
+172 -207
View File
@@ -1,39 +1,36 @@
- title: Get started
sections:
- sections:
- local: index
title: Diffusers
- local: installation
title: Installation
title: 🧨 Diffusers
- local: quicktour
title: Quicktour
- local: stable_diffusion
title: Effective and efficient diffusion
- title: DiffusionPipeline
isExpanded: false
sections:
- local: using-diffusers/loading
title: Load pipelines
- local: installation
title: Installation
title: Get started
- sections:
- local: tutorials/tutorial_overview
title: Overview
- local: using-diffusers/write_own_pipeline
title: Understanding pipelines, models and schedulers
- local: tutorials/autopipeline
title: AutoPipeline
- local: tutorials/basic_training
title: Train a diffusion model
title: Tutorials
- sections:
- local: using-diffusers/loading
title: Load pipelines
- local: using-diffusers/custom_pipeline_overview
title: Load community pipelines and components
- local: using-diffusers/callback
title: Pipeline callbacks
- local: using-diffusers/reusing_seeds
title: Reproducible pipelines
- local: using-diffusers/schedulers
title: Load schedulers and models
- local: using-diffusers/scheduler_features
title: Scheduler features
- local: using-diffusers/other-formats
title: Model files and layouts
- local: using-diffusers/push_to_hub
title: Push files to the Hub
- title: Adapters
isExpanded: false
sections:
title: Load pipelines and adapters
- sections:
- local: tutorials/using_peft_for_inference
title: LoRA
- local: using-diffusers/ip_adapter
@@ -46,12 +43,25 @@
title: DreamBooth
- local: using-diffusers/textual_inversion_inference
title: Textual inversion
- title: Inference
title: Adapters
isExpanded: false
sections:
- local: using-diffusers/weighted_prompts
title: Prompt techniques
- sections:
- local: using-diffusers/unconditional_image_generation
title: Unconditional image generation
- local: using-diffusers/conditional_image_generation
title: Text-to-image
- local: using-diffusers/img2img
title: Image-to-image
- local: using-diffusers/inpaint
title: Inpainting
- local: using-diffusers/text-img2vid
title: Video generation
- local: using-diffusers/depth2img
title: Depth-to-image
title: Generative tasks
- sections:
- local: using-diffusers/overview_techniques
title: Overview
- local: using-diffusers/create_a_server
title: Create a server
- local: using-diffusers/batched_inference
@@ -66,38 +76,14 @@
title: Reproducible pipelines
- local: using-diffusers/image_quality
title: Controlling image quality
- title: Inference optimization
isExpanded: false
sections:
- local: optimization/fp16
title: Accelerate inference
- local: optimization/cache
title: Caching
- local: optimization/memory
title: Reduce memory usage
- local: optimization/speed-memory-optims
title: Compile and offloading quantized models
- title: Community optimizations
sections:
- local: optimization/pruna
title: Pruna
- local: optimization/xformers
title: xFormers
- local: optimization/tome
title: Token merging
- local: optimization/deepcache
title: DeepCache
- local: optimization/tgate
title: TGATE
- local: optimization/xdit
title: xDiT
- local: optimization/para_attn
title: ParaAttention
- title: Hybrid Inference
isExpanded: false
sections:
- local: using-diffusers/weighted_prompts
title: Prompt techniques
title: Inference techniques
- sections:
- local: advanced_inference/outpaint
title: Outpainting
title: Advanced inference
- sections:
- local: hybrid_inference/overview
title: Overview
- local: hybrid_inference/vae_decode
@@ -106,110 +92,8 @@
title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
- title: Modular Diffusers
isExpanded: false
sections:
- local: modular_diffusers/overview
title: Overview
- local: modular_diffusers/modular_pipeline
title: Modular Pipeline
- local: modular_diffusers/components_manager
title: Components Manager
- local: modular_diffusers/modular_diffusers_states
title: Modular Diffusers States
- local: modular_diffusers/pipeline_block
title: Pipeline Block
- local: modular_diffusers/sequential_pipeline_blocks
title: Sequential Pipeline Blocks
- local: modular_diffusers/loop_sequential_pipeline_blocks
title: Loop Sequential Pipeline Blocks
- local: modular_diffusers/auto_pipeline_blocks
title: Auto Pipeline Blocks
- local: modular_diffusers/end_to_end_guide
title: End-to-End Example
- title: Training
isExpanded: false
sections:
- local: training/overview
title: Overview
- local: training/create_dataset
title: Create a dataset for training
- local: training/adapt_a_model
title: Adapt a model to a new task
- local: tutorials/basic_training
title: Train a diffusion model
- title: Models
sections:
- local: training/unconditional_training
title: Unconditional image generation
- local: training/text2image
title: Text-to-image
- local: training/sdxl
title: Stable Diffusion XL
- local: training/kandinsky
title: Kandinsky 2.2
- local: training/wuerstchen
title: Wuerstchen
- local: training/controlnet
title: ControlNet
- local: training/t2i_adapters
title: T2I-Adapters
- local: training/instructpix2pix
title: InstructPix2Pix
- local: training/cogvideox
title: CogVideoX
- title: Methods
sections:
- local: training/text_inversion
title: Textual Inversion
- local: training/dreambooth
title: DreamBooth
- local: training/lora
title: LoRA
- local: training/custom_diffusion
title: Custom Diffusion
- local: training/lcm_distill
title: Latent Consistency Distillation
- local: training/ddpo
title: Reinforcement learning training with DDPO
- title: Quantization
isExpanded: false
sections:
- local: quantization/overview
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
- local: quantization/gguf
title: gguf
- local: quantization/torchao
title: torchao
- local: quantization/quanto
title: quanto
- title: Model accelerators and hardware
isExpanded: false
sections:
- local: using-diffusers/stable_diffusion_jax_how_to
title: JAX/Flax
- local: optimization/onnx
title: ONNX
- local: optimization/open_vino
title: OpenVINO
- local: optimization/coreml
title: Core ML
- local: optimization/mps
title: Metal Performance Shaders (MPS)
- local: optimization/habana
title: Intel Gaudi
- local: optimization/neuron
title: AWS Neuron
- title: Specific pipeline examples
isExpanded: false
sections:
title: Hybrid Inference
- sections:
- local: using-diffusers/consisid
title: ConsisID
- local: using-diffusers/sdxl
@@ -234,30 +118,106 @@
title: Stable Video Diffusion
- local: using-diffusers/marigold_usage
title: Marigold Computer Vision
- title: Resources
isExpanded: false
sections:
- title: Task recipes
title: Specific pipeline examples
- sections:
- local: training/overview
title: Overview
- local: training/create_dataset
title: Create a dataset for training
- local: training/adapt_a_model
title: Adapt a model to a new task
- isExpanded: false
sections:
- local: using-diffusers/unconditional_image_generation
- local: training/unconditional_training
title: Unconditional image generation
- local: using-diffusers/conditional_image_generation
- local: training/text2image
title: Text-to-image
- local: using-diffusers/img2img
title: Image-to-image
- local: using-diffusers/inpaint
title: Inpainting
- local: advanced_inference/outpaint
title: Outpainting
- local: using-diffusers/text-img2vid
title: Video generation
- local: using-diffusers/depth2img
title: Depth-to-image
- local: using-diffusers/write_own_pipeline
title: Understanding pipelines, models and schedulers
- local: community_projects
title: Projects built with Diffusers
- local: training/sdxl
title: Stable Diffusion XL
- local: training/kandinsky
title: Kandinsky 2.2
- local: training/wuerstchen
title: Wuerstchen
- local: training/controlnet
title: ControlNet
- local: training/t2i_adapters
title: T2I-Adapters
- local: training/instructpix2pix
title: InstructPix2Pix
- local: training/cogvideox
title: CogVideoX
title: Models
- isExpanded: false
sections:
- local: training/text_inversion
title: Textual Inversion
- local: training/dreambooth
title: DreamBooth
- local: training/lora
title: LoRA
- local: training/custom_diffusion
title: Custom Diffusion
- local: training/lcm_distill
title: Latent Consistency Distillation
- local: training/ddpo
title: Reinforcement learning training with DDPO
title: Methods
title: Training
- sections:
- local: quantization/overview
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
- local: quantization/gguf
title: gguf
- local: quantization/torchao
title: torchao
- local: quantization/quanto
title: quanto
title: Quantization Methods
- sections:
- local: optimization/fp16
title: Accelerate inference
- local: optimization/cache
title: Caching
- local: optimization/memory
title: Reduce memory usage
- local: optimization/speed-memory-optims
title: Compile and offloading quantized models
- local: optimization/pruna
title: Pruna
- local: optimization/xformers
title: xFormers
- local: optimization/tome
title: Token merging
- local: optimization/deepcache
title: DeepCache
- local: optimization/tgate
title: TGATE
- local: optimization/xdit
title: xDiT
- local: optimization/para_attn
title: ParaAttention
- sections:
- local: using-diffusers/stable_diffusion_jax_how_to
title: JAX/Flax
- local: optimization/onnx
title: ONNX
- local: optimization/open_vino
title: OpenVINO
- local: optimization/coreml
title: Core ML
title: Optimized model formats
- sections:
- local: optimization/mps
title: Metal Performance Shaders (MPS)
- local: optimization/habana
title: Intel Gaudi
- local: optimization/neuron
title: AWS Neuron
title: Optimized hardware
title: Accelerate inference and reduce memory
- sections:
- local: conceptual/philosophy
title: Philosophy
- local: using-diffusers/controlling_generation
@@ -268,11 +228,13 @@
title: Diffusers' Ethical Guidelines
- local: conceptual/evaluation
title: Evaluating Diffusion Models
- title: API
isExpanded: false
sections:
- title: Main Classes
title: Conceptual Guides
- sections:
- local: community_projects
title: Projects built with Diffusers
title: Community Projects
- sections:
- isExpanded: false
sections:
- local: api/configuration
title: Configuration
@@ -282,7 +244,8 @@
title: Outputs
- local: api/quantization
title: Quantization
- title: Loaders
title: Main Classes
- isExpanded: false
sections:
- local: api/loaders/ip_adapter
title: IP-Adapter
@@ -298,14 +261,14 @@
title: SD3Transformer2D
- local: api/loaders/peft
title: PEFT
- title: Models
title: Loaders
- isExpanded: false
sections:
- local: api/models/overview
title: Overview
- local: api/models/auto_model
title: AutoModel
- title: ControlNets
sections:
- sections:
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_union
@@ -320,8 +283,8 @@
title: SD3ControlNetModel
- local: api/models/controlnet_sparsectrl
title: SparseControlNetModel
- title: Transformers
sections:
title: ControlNets
- sections:
- local: api/models/allegro_transformer3d
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
@@ -370,8 +333,6 @@
title: SanaTransformer2DModel
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/skyreels_v2_transformer_3d
title: SkyReelsV2Transformer3DModel
- local: api/models/stable_audio_transformer
title: StableAudioDiTModel
- local: api/models/transformer2d
@@ -380,8 +341,8 @@
title: TransformerTemporalModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
- title: UNets
sections:
title: Transformers
- sections:
- local: api/models/stable_cascade_unet
title: StableCascadeUNet
- local: api/models/unet
@@ -396,8 +357,8 @@
title: UNetMotionModel
- local: api/models/uvit2d
title: UViT2DModel
- title: VAEs
sections:
title: UNets
- sections:
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
@@ -428,7 +389,9 @@
title: Tiny AutoEncoder
- local: api/models/vq
title: VQModel
- title: Pipelines
title: VAEs
title: Models
- isExpanded: false
sections:
- local: api/pipelines/overview
title: Overview
@@ -564,14 +527,11 @@
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
- local: api/pipelines/skyreels_v2
title: SkyReels-V2
- local: api/pipelines/stable_audio
title: Stable Audio
- local: api/pipelines/stable_cascade
title: Stable Cascade
- title: Stable Diffusion
sections:
- sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview
- local: api/pipelines/stable_diffusion/depth2img
@@ -608,6 +568,7 @@
title: T2I-Adapter
- local: api/pipelines/stable_diffusion/text2img
title: Text-to-image
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
- local: api/pipelines/text_to_video
@@ -626,7 +587,8 @@
title: Wan
- local: api/pipelines/wuerstchen
title: Wuerstchen
- title: Schedulers
title: Pipelines
- isExpanded: false
sections:
- local: api/schedulers/overview
title: Overview
@@ -696,7 +658,8 @@
title: UniPCMultistepScheduler
- local: api/schedulers/vq_diffusion
title: VQDiffusionScheduler
- title: Internal classes
title: Schedulers
- isExpanded: false
sections:
- local: api/internal_classes_overview
title: Overview
@@ -714,3 +677,5 @@
title: VAE Image Processor
- local: api/video_processor
title: Video Processor
title: Internal classes
title: API
+2 -7
View File
@@ -26,7 +26,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
- [`SkyReelsV2LoraLoaderMixin`] provides similar functions for [SkyReels-V2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/skyreels_v2).
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
@@ -93,10 +92,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
## SkyReelsV2LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.SkyReelsV2LoraLoaderMixin
## AmusedLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
@@ -105,6 +100,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
## LoraBaseMixin
## WanLoraLoaderMixin
[[autodoc]] loaders.lora_base.LoraBaseMixin
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
@@ -1,30 +0,0 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# SkyReelsV2Transformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [SkyReels-V2](https://github.com/SkyworkAI/SkyReels-V2) by the Skywork AI.
The model can be loaded with the following code snippet.
```python
from diffusers import SkyReelsV2Transformer3DModel
transformer = SkyReelsV2Transformer3DModel.from_pretrained("Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## SkyReelsV2Transformer3DModel
[[autodoc]] SkyReelsV2Transformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
+1 -1
View File
@@ -36,7 +36,7 @@ import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
pipe.enabe_model_cpu_offload()
prompt = [
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
-367
View File
@@ -1,367 +0,0 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</a>
</div>
</div>
# SkyReels-V2: Infinite-length Film Generative model
[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team.
*Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation. To address these limitations, we propose SkyReels-V2, an Infinite-length Film Generative Model, that synergizes Multi-modal Large Language Model (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing Framework. Firstly, we design a comprehensive structural representation of video that combines the general descriptions by the Multi-modal LLM and the detailed shot language by sub-expert models. Aided with human annotation, we then train a unified Video Captioner, named SkyCaptioner-V1, to efficiently label the video data. Secondly, we establish progressive-resolution pretraining for the fundamental video generation, followed by a four-stage post-training enhancement: Initial concept-balanced Supervised Fine-Tuning (SFT) improves baseline quality; Motion-specific Reinforcement Learning (RL) training with human-annotated and synthetic distortion data addresses dynamic artifacts; Our diffusion forcing framework with non-decreasing noise schedules enables long-video synthesis in an efficient search space; Final high-quality SFT refines visual fidelity. All the code and models are available at [this https URL](https://github.com/SkyworkAI/SkyReels-V2).*
You can find all the original SkyReels-V2 checkpoints under the [Skywork](https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9) organization.
The following SkyReels-V2 models are supported in Diffusers:
- [SkyReels-V2 DF 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers)
- [SkyReels-V2 DF 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P-Diffusers)
- [SkyReels-V2 DF 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P-Diffusers)
- [SkyReels-V2 T2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-540P-Diffusers)
- [SkyReels-V2 T2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-720P-Diffusers)
- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)
> [!TIP]
> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.
### A _Visual_ Demonstration
An example with these parameters:
base_num_frames=97, num_frames=97, num_inference_steps=30, ar_step=5, causal_block_size=5
vae_scale_factor_temporal -> 4
num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each
base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 → blocks = 25//5 = 5 blocks
This 5 blocks means the maximum context length of the model is 25 frames in the latent space.
Asynchronous Processing Timeline:
┌─────────────────────────────────────────────────────────────────┐
│ Steps: 1 6 11 16 21 26 31 36 41 46 50 │
│ Block 1: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
│ Block 2: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
│ Block 3: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
│ Block 4: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
│ Block 5: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
└─────────────────────────────────────────────────────────────────┘
For Long Videos (num_frames > base_num_frames):
base_num_frames acts as the "sliding window size" for processing long videos.
Example: 257-frame video with base_num_frames=97, overlap_history=17
┌──── Iteration 1 (frames 1-97) ────┐
│ Processing window: 97 frames │ → 5 blocks, async processing
│ Generates: frames 1-97 │
└───────────────────────────────────┘
┌────── Iteration 2 (frames 81-177) ──────┐
│ Processing window: 97 frames │
│ Overlap: 17 frames (81-97) from prev │ → 5 blocks, async processing
│ Generates: frames 98-177 │
└─────────────────────────────────────────┘
┌────── Iteration 3 (frames 161-257) ──────┐
│ Processing window: 97 frames │
│ Overlap: 17 frames (161-177) from prev │ → 5 blocks, async processing
│ Generates: frames 178-257 │
└──────────────────────────────────────────┘
Each iteration independently runs the asynchronous processing with its own 5 blocks.
base_num_frames controls:
1. Memory usage (larger window = more VRAM)
2. Model context length (must match training constraints)
3. Number of blocks per iteration (base_num_latent_frames // causal_block_size)
Each block takes 30 steps to complete denoising.
Block N starts at step: 1 + (N-1) x ar_step
Total steps: 30 + (5-1) x 5 = 50 steps
Synchronous mode (ar_step=0) would process all blocks/frames simultaneously:
┌──────────────────────────────────────────────┐
│ Steps: 1 ... 30 │
│ All blocks: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
└──────────────────────────────────────────────┘
Total steps: 30 steps
An example on how the step matrix is constructed for asynchronous processing:
Given the parameters: (num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5)
- num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25
- step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,
941, 932, 922, 912, 901, 888, 874, 859, 841, 822,
799, 773, 743, 708, 666, 615, 551, 470, 363, 216]
The algorithm creates a 50x25 step_matrix where:
- Row 1: [999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- Row 2: [995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- Row 3: [991, 991, 991, 991, 991, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- ...
- Row 7: [969, 969, 969, 969, 969, 995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- ...
- Row 21: [799, 799, 799, 799, 799, 888, 888, 888, 888, 888, 941, 941, 941, 941, 941, 975, 975, 975, 975, 975, 999, 999, 999, 999, 999]
- ...
- Row 35: [ 0, 0, 0, 0, 0, 216, 216, 216, 216, 216, 666, 666, 666, 666, 666, 822, 822, 822, 822, 822, 901, 901, 901, 901, 901]
- ...
- Row 42: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 551, 551, 551, 551, 551, 773, 773, 773, 773, 773]
- ...
- Row 50: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 216, 216, 216, 216, 216]
Detailed Row 6 Analysis:
- step_matrix[5]: [ 975, 975, 975, 975, 975, 999, 999, 999, 999, 999, 999, ..., 999]
- step_index[5]: [ 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 0, ..., 0]
- step_update_mask[5]: [True,True,True,True,True,True,True,True,True,True,False, ...,False]
- valid_interval[5]: (0, 25)
Key Pattern: Block i lags behind Block i-1 by exactly ar_step=5 timesteps, creating the
staggered "diffusion forcing" effect where later blocks condition on cleaner earlier blocks.
### Text-to-Video Generation
The example below demonstrates how to generate a video from text.
<hfoptions id="T2V usage">
<hfoption id="T2V memory">
Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
From the original repo:
>You can use --ar_step 5 to enable asynchronous inference. When asynchronous inference, --causal_block_size 5 is recommended while it is not supposed to be set for synchronous generation... Asynchronous inference will take more steps to diffuse the whole sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance.
```py
# pip install ftfy
import torch
from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32)
transformer = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
"Skywork/SkyReels-V2-DF-14B-540P-Diffusers",
vae=vae,
transformer=transformer,
torch_dtype=torch.bfloat16
)
flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline = pipeline.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
output = pipeline(
prompt=prompt,
num_inference_steps=30,
height=544, # 720 for 720P
width=960, # 1280 for 720P
num_frames=97,
base_num_frames=97, # 121 for 720P
ar_step=5, # Controls asynchronous inference (0 for synchronous mode)
causal_block_size=5, # Number of frames in each block for asynchronous processing
overlap_history=None, # Number of frames to overlap for smooth transitions in long videos; 17 for long video generations
addnoise_condition=20, # Improves consistency in long video generation
).frames[0]
export_to_video(output, "T2V.mp4", fps=24, quality=8)
```
</hfoption>
</hfoptions>
### First-Last-Frame-to-Video Generation
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.
<hfoptions id="FLF2V usage">
<hfoption id="usage">
```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline.to("cuda")
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
def aspect_ratio_resize(image, pipeline, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
def center_crop_resize(image, height, width):
# Calculate resize ratio to match first frame dimensions
resize_ratio = max(width / image.width, height / image.height)
# Resize the image
width = round(image.width * resize_ratio)
height = round(image.height * resize_ratio)
size = [width, height]
image = TF.center_crop(image, size)
return image, height, width
first_frame, height, width = aspect_ratio_resize(first_frame, pipeline)
if last_frame.size != first_frame.size:
last_frame, _, _ = center_crop_resize(last_frame, height, width)
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipeline(
image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.0
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=8)
```
</hfoption>
</hfoptions>
### Video-to-Video Generation
<hfoptions id="V2V usage">
<hfoption id="usage">
`SkyReelsV2DiffusionForcingVideoToVideoPipeline` extends a given video.
```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_video
model_id = "Skywork/SkyReels-V2-DF-14B-540P-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline.to("cuda")
video = load_video("input_video.mp4")
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipeline(
video=video, prompt=prompt, height=544, width=960, guidance_scale=5.0,
num_inference_steps=30, num_frames=257, base_num_frames=97#, ar_step=5, causal_block_size=5,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=8)
# Total frames will be the number of frames of given video + 257
```
</hfoption>
</hfoptions>
## Notes
- SkyReels-V2 supports LoRAs with [`~loaders.SkyReelsV2LoraLoaderMixin.load_lora_weights`].
<details>
<summary>Show example code</summary>
```py
# pip install ftfy
import torch
from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained(
"Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
"Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", vae=vae, torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
pipeline.load_lora_weights("benjamin-paine/steamboat-willie-1.3b", adapter_name="steamboat-willie")
pipeline.set_adapters("steamboat-willie")
pipeline.enable_model_cpu_offload()
# use "steamboat willie style" to trigger the LoRA
prompt = """
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
output = pipeline(
prompt=prompt,
num_frames=97,
guidance_scale=6.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24)
```
</details>
## SkyReelsV2DiffusionForcingPipeline
[[autodoc]] SkyReelsV2DiffusionForcingPipeline
- all
- __call__
## SkyReelsV2DiffusionForcingImageToVideoPipeline
[[autodoc]] SkyReelsV2DiffusionForcingImageToVideoPipeline
- all
- __call__
## SkyReelsV2DiffusionForcingVideoToVideoPipeline
[[autodoc]] SkyReelsV2DiffusionForcingVideoToVideoPipeline
- all
- __call__
## SkyReelsV2Pipeline
[[autodoc]] SkyReelsV2Pipeline
- all
- __call__
## SkyReelsV2ImageToVideoPipeline
[[autodoc]] SkyReelsV2ImageToVideoPipeline
- all
- __call__
## SkyReelsV2PipelineOutput
[[autodoc]] pipelines.skyreels_v2.pipeline_output.SkyReelsV2PipelineOutput
@@ -1,316 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoPipelineBlocks
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
`AutoPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that automatically selects which sub-blocks to run based on the inputs provided at runtime, creating conditional workflows that adapt to different scenarios. The main purpose is convenience and portability - for developers, you can package everything into one workflow, making it easier to share and use.
In this tutorial, we will show you how to create an `AutoPipelineBlocks` and learn more about how the conditional selection works.
<Tip>
Other types of multi-blocks include [SequentialPipelineBlocks](sequential_pipeline_blocks.md) (for linear workflows) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md).
Additionally, like all `ModularPipelineBlocks`, `AutoPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
</Tip>
For example, you might want to support text-to-image and image-to-image tasks. Instead of creating two separate pipelines, you can create an `AutoPipelineBlocks` that automatically chooses the workflow based on whether an `image` input is provided.
Let's see an example. We'll use the helper function from the [PipelineBlock guide](./pipeline_block.md) to create our blocks:
**Helper Function**
```py
from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
import torch
def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
class TestBlock(PipelineBlock):
model_name = "test"
@property
def inputs(self):
return inputs
@property
def intermediate_inputs(self):
return intermediate_inputs
@property
def intermediate_outputs(self):
return intermediate_outputs
@property
def description(self):
return description if description is not None else ""
def __call__(self, components, state):
block_state = self.get_block_state(state)
if block_fn is not None:
block_state = block_fn(block_state, state)
self.set_block_state(state, block_state)
return components, state
return TestBlock
```
Now let's create a dummy `AutoPipelineBlocks` that includes dummy text-to-image, image-to-image, and inpaint pipelines.
```py
from diffusers.modular_pipelines import AutoPipelineBlocks
# These are dummy blocks and we only focus on "inputs" for our purpose
inputs = [InputParam(name="prompt")]
# block_fn prints out which workflow is running so we can see the execution order at runtime
block_fn = lambda x, y: print("running the text-to-image workflow")
block_t2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a text-to-image workflow!")
inputs = [InputParam(name="prompt"), InputParam(name="image")]
block_fn = lambda x, y: print("running the image-to-image workflow")
block_i2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a image-to-image workflow!")
inputs = [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
block_fn = lambda x, y: print("running the inpaint workflow")
block_inpaint_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a inpaint workflow!")
class AutoImageBlocks(AutoPipelineBlocks):
# List of sub-block classes to choose from
block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
# Names for each block in the same order
block_names = ["inpaint", "img2img", "text2img"]
# Trigger inputs that determine which block to run
# - "mask" triggers inpaint workflow
# - "image" triggers img2img workflow (but only if mask is not provided)
# - if none of above, runs the text2img workflow (default)
block_trigger_inputs = ["mask", "image", None]
# Description is extremely important for AutoPipelineBlocks
@property
def description(self):
return (
"Pipeline generates images given different types of conditions!\n"
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
+ " - inpaint workflow is run when `mask` is provided.\n"
+ " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
+ " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
)
# Create the blocks
auto_blocks = AutoImageBlocks()
# convert to pipeline
auto_pipeline = auto_blocks.init_pipeline()
```
Now we have created an `AutoPipelineBlocks` that contains 3 sub-blocks. Notice the warning message at the top - this automatically appears in every `ModularPipelineBlocks` that contains `AutoPipelineBlocks` to remind end users that dynamic block selection happens at runtime.
```py
AutoImageBlocks(
Class: AutoPipelineBlocks
====================================================================================================
This pipeline contains blocks that are selected at runtime based on inputs.
Trigger Inputs: ['mask', 'image']
====================================================================================================
Description: Pipeline generates images given different types of conditions!
This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
- inpaint workflow is run when `mask` is provided.
- img2img workflow is run when `image` is provided (but only when `mask` is not provided).
- text2img workflow is run when neither `image` nor `mask` is provided.
Sub-Blocks:
inpaint [trigger: mask] (TestBlock)
Description: I'm a inpaint workflow!
img2img [trigger: image] (TestBlock)
Description: I'm a image-to-image workflow!
text2img [default] (TestBlock)
Description: I'm a text-to-image workflow!
)
```
Check out the documentation with `print(auto_pipeline.doc)`:
```py
>>> print(auto_pipeline.doc)
class AutoImageBlocks
Pipeline generates images given different types of conditions!
This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
- inpaint workflow is run when `mask` is provided.
- img2img workflow is run when `image` is provided (but only when `mask` is not provided).
- text2img workflow is run when neither `image` nor `mask` is provided.
Inputs:
prompt (`None`, *optional*):
image (`None`, *optional*):
mask (`None`, *optional*):
```
There is a fundamental trade-off of AutoPipelineBlocks: it trades clarity for convenience. While it is really easy for packaging multiple workflows, it can become confusing without proper documentation. e.g. if we just throw a pipeline at you and tell you that it contains 3 sub-blocks and takes 3 inputs `prompt`, `image` and `mask`, and ask you to run an image-to-image workflow: if you don't have any prior knowledge on how these pipelines work, you would be pretty clueless, right?
This pipeline we just made though, has a docstring that shows all available inputs and workflows and explains how to use each with different inputs. So it's really helpful for users. For example, it's clear that you need to pass `image` to run img2img. This is why the description field is absolutely critical for AutoPipelineBlocks. We highly recommend you to explain the conditional logic very well for each `AutoPipelineBlocks` you would make. We also recommend to always test individual pipelines first before packaging them into AutoPipelineBlocks.
Let's run this auto pipeline with different inputs to see if the conditional logic works as described. Remember that we have added `print` in each `PipelineBlock`'s `__call__` method to print out its workflow name, so it should be easy to tell which one is running:
```py
>>> _ = auto_pipeline(image="image", mask="mask")
running the inpaint workflow
>>> _ = auto_pipeline(image="image")
running the image-to-image workflow
>>> _ = auto_pipeline(prompt="prompt")
running the text-to-image workflow
>>> _ = auto_pipeline(image="prompt", mask="mask")
running the inpaint workflow
```
However, even with documentation, it can become very confusing when AutoPipelineBlocks are combined with other blocks. The complexity grows quickly when you have nested AutoPipelineBlocks or use them as sub-blocks in larger pipelines.
Let's make another `AutoPipelineBlocks` - this one only contains one block, and it does not include `None` in its `block_trigger_inputs` (which corresponds to the default block to run when none of the trigger inputs are provided). This means this block will be skipped if the trigger input (`ip_adapter_image`) is not provided at runtime.
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
inputs = [InputParam(name="ip_adapter_image")]
block_fn = lambda x, y: print("running the ip-adapter workflow")
block_ipa_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a IP-adapter workflow!")
class AutoIPAdapter(AutoPipelineBlocks):
block_classes = [block_ipa_cls]
block_names = ["ip-adapter"]
block_trigger_inputs = ["ip_adapter_image"]
@property
def description(self):
return "Run IP Adapter step if `ip_adapter_image` is provided."
```
Now let's combine these 2 auto blocks together into a `SequentialPipelineBlocks`:
```py
auto_ipa_blocks = AutoIPAdapter()
blocks_dict = InsertableDict()
blocks_dict["ip-adapter"] = auto_ipa_blocks
blocks_dict["image-generation"] = auto_blocks
all_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
pipeline = all_blocks.init_pipeline()
```
Let's take a look: now things get more confusing. In this particular example, you could still try to explain the conditional logic in the `description` field here - there are only 4 possible execution paths so it's doable. However, since this is a `SequentialPipelineBlocks` that could contain many more blocks, the complexity can quickly get out of hand as the number of blocks increases.
```py
>>> all_blocks
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
====================================================================================================
This pipeline contains blocks that are selected at runtime based on inputs.
Trigger Inputs: ['image', 'mask', 'ip_adapter_image']
Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('image')`).
====================================================================================================
Description:
Sub-Blocks:
[0] ip-adapter (AutoIPAdapter)
Description: Run IP Adapter step if `ip_adapter_image` is provided.
[1] image-generation (AutoImageBlocks)
Description: Pipeline generates images given different types of conditions!
This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
- inpaint workflow is run when `mask` is provided.
- img2img workflow is run when `image` is provided (but only when `mask` is not provided).
- text2img workflow is run when neither `image` nor `mask` is provided.
)
```
This is when the `get_execution_blocks()` method comes in handy - it basically extracts a `SequentialPipelineBlocks` that only contains the blocks that are actually run based on your inputs.
Let's try some examples:
`mask`: we expect it to skip the first ip-adapter since `ip_adapter_image` is not provided, and then run the inpaint for the second block.
```py
>>> all_blocks.get_execution_blocks('mask')
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
Description:
Sub-Blocks:
[0] image-generation (TestBlock)
Description: I'm a inpaint workflow!
)
```
Let's also actually run the pipeline to confirm:
```py
>>> _ = pipeline(mask="mask")
skipping auto block: AutoIPAdapter
running the inpaint workflow
```
Try a few more:
```py
print(f"inputs: ip_adapter_image:")
blocks_select = all_blocks.get_execution_blocks('ip_adapter_image')
print(f"expected_execution_blocks: {blocks_select}")
print(f"actual execution blocks:")
_ = pipeline(ip_adapter_image="ip_adapter_image", prompt="prompt")
# expect to see ip-adapter + text2img
print(f"inputs: image:")
blocks_select = all_blocks.get_execution_blocks('image')
print(f"expected_execution_blocks: {blocks_select}")
print(f"actual execution blocks:")
_ = pipeline(image="image", prompt="prompt")
# expect to see img2img
print(f"inputs: prompt:")
blocks_select = all_blocks.get_execution_blocks('prompt')
print(f"expected_execution_blocks: {blocks_select}")
print(f"actual execution blocks:")
_ = pipeline(prompt="prompt")
# expect to see text2img (prompt is not a trigger input so fallback to default)
print(f"inputs: mask + ip_adapter_image:")
blocks_select = all_blocks.get_execution_blocks('mask','ip_adapter_image')
print(f"expected_execution_blocks: {blocks_select}")
print(f"actual execution blocks:")
_ = pipeline(mask="mask", ip_adapter_image="ip_adapter_image")
# expect to see ip-adapter + inpaint
```
In summary, `AutoPipelineBlocks` is a good tool for packaging multiple workflows into a single, convenient interface and it can greatly simplify the user experience. However, always provide clear descriptions explaining the conditional logic, test individual pipelines first before combining them, and use `get_execution_blocks()` to understand runtime behavior in complex compositions.
@@ -1,514 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Components Manager
<Tip warning={true}>
🧪 **Experimental Feature**: This is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
The Components Manager is a central model registry and management system in diffusers. It lets you add models then reuse them across multiple pipelines and workflows. It tracks all models in one place with useful metadata such as model size, device placement and loaded adapters (LoRA, IP-Adapter). It has mechanisms in place to prevent duplicate model instances, enables memory-efficient sharing. Most significantly, it offers offloading that works across pipelines — unlike regular DiffusionPipeline offloading (i.e. `enable_model_cpu_offload` and `enable_sequential_cpu_offload`) which is limited to one pipeline with predefined sequences, the Components Manager automatically manages your device memory across all your models and workflows.
## Basic Operations
Let's start with the most basic operations. First, create a Components Manager:
```py
from diffusers import ComponentsManager
comp = ComponentsManager()
```
Use the `add(name, component)` method to register a component. It returns a unique ID that combines the component name with the object's unique identifier (using Python's `id()` function):
```py
from diffusers import AutoModel
text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
# Returns component_id like 'text_encoder_139917733042864'
component_id = comp.add("text_encoder", text_encoder)
```
You can view all registered components and their metadata:
```py
>>> comp
Components:
===============================================================================================================================================
Models:
-----------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
-----------------------------------------------------------------------------------------------------------------------------------------------
text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
-----------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
And remove components using their unique ID:
```py
comp.remove("text_encoder_139917733042864")
```
## Duplicate Detection
The Components Manager automatically detects and prevents duplicate model instances to save memory and avoid confusion. Let's walk through how this works in practice.
When you try to add the same object twice, the manager will warn you and return the existing ID:
```py
>>> comp.add("text_encoder", text_encoder)
'text_encoder_139917733042864'
>>> comp.add("text_encoder", text_encoder)
ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917733042864'
'text_encoder_139917733042864'
```
Even if you add the same object under a different name, it will still be detected as a duplicate:
```py
>>> comp.add("clip", text_encoder)
ComponentsManager: adding component 'clip' as 'clip_139917733042864', but it is duplicate of 'text_encoder_139917733042864'
To remove a duplicate, call `components_manager.remove('<component_id>')`.
'clip_139917733042864'
```
However, there's a more subtle case where duplicate detection becomes tricky. When you load the same model into different objects, the manager can't detect duplicates unless you use `ComponentSpec`. For example:
```py
>>> text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
>>> comp.add("text_encoder", text_encoder_2)
'text_encoder_139917732983664'
```
This creates a problem - you now have two copies of the same model consuming double the memory:
```py
>>> comp
Components:
===============================================================================================================================================
Models:
-----------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
-----------------------------------------------------------------------------------------------------------------------------------------------
text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
clip_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
text_encoder_139917732983664 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
-----------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
We recommend using `ComponentSpec` to load your models. Models loaded with `ComponentSpec` get tagged with a unique ID that encodes their loading parameters, allowing the Components Manager to detect when different objects represent the same underlying checkpoint:
```py
from diffusers import ComponentSpec, ComponentsManager
from transformers import CLIPTextModel
comp = ComponentsManager()
# Create ComponentSpec for the first text encoder
spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from same repo/subfolder)
spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel)
# Load and add both components - the manager will detect they're the same model
comp.add("text_encoder", spec.load())
comp.add("text_encoder_duplicated", spec_duplicated.load())
```
Now the manager detects the duplicate and warns you:
```out
ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('<component_id>')`.
'text_encoder_duplicated_139917580682672'
```
Both models now show the same `load_id`, making it clear they're the same model:
```py
>>> comp
Components:
======================================================================================================================================================================================================
Models:
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
text_encoder_139918506246832 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A
text_encoder_duplicated_139917580682672 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
## Collections
Collections are labels you can assign to components for better organization and management. You add a component under a collection by passing the `collection=` parameter when you add the component to the manager, i.e. `add(name, component, collection=...)`. Within each collection, only one component per name is allowed - if you add a second component with the same name, the first one is automatically removed.
Here's how collections work in practice:
```py
comp = ComponentsManager()
# Create ComponentSpec for the first UNet (SDXL base)
spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
# Create ComponentSpec for a different UNet (Juggernaut-XL)
spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")
# Add both UNets to the same collection - the second one will replace the first
comp.add("unet", spec.load(), collection="sdxl")
comp.add("unet", spec2.load(), collection="sdxl")
```
The manager automatically removes the old UNet and adds the new one:
```out
ComponentsManager: removing existing unet from collection 'sdxl': unet_139917723891888
'unet_139917723893136'
```
Only one UNet remains in the collection:
```py
>>> comp
Components:
====================================================================================================================================================================
Models:
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
unet_139917723893136 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | sdxl
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
For example, in node-based systems, you can mark all models loaded from one node with the same collection label, automatically replace models when user loads new checkpoints under same name, batch delete all models in a collection when a node is removed.
## Retrieving Components
The Components Manager provides several methods to retrieve registered components.
The `get_one()` method returns a single component and supports pattern matching for the `name` parameter. You can use:
- exact matches like `comp.get_one(name="unet")`
- wildcards like `comp.get_one(name="unet*")` for components starting with "unet"
- exclusion patterns like `comp.get_one(name="!unet")` to exclude components named "unet"
- OR patterns like `comp.get_one(name="unet|vae")` to match either "unet" OR "vae".
Optionally, You can add collection and load_id as filters e.g. `comp.get_one(name="unet", collection="sdxl")`. If multiple components match, `get_one()` throws an error.
Another useful method is `get_components_by_names()`, which takes a list of names and returns a dictionary mapping names to components. This is particularly helpful with modular pipelines since they provide lists of required component names, and the returned dictionary can be directly passed to `pipeline.update_components()`.
```py
# Get components by name list
component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
# Returns: {"text_encoder": component1, "unet": component2, "vae": component3}
```
## Using Components Manager with Modular Pipelines
The Components Manager integrates seamlessly with Modular Pipelines. All you need to do is pass a Components Manager instance to `from_pretrained()` or `init_pipeline()` with an optional `collection` parameter:
```py
from diffusers import ModularPipeline, ComponentsManager
comp = ComponentsManager()
pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
```
By default, modular pipelines don't load components immediately, so both the pipeline and Components Manager start empty:
```py
>>> comp
Components:
==================================================
No components registered.
==================================================
```
When you load components on the pipeline, they are automatically registered in the Components Manager:
```py
>>> pipe.load_components(names="unet")
>>> comp
Components:
==============================================================================================================================================================
Models:
--------------------------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
--------------------------------------------------------------------------------------------------------------------------------------------------------------
unet_139917726686304 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1
--------------------------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
Now let's load all default components and then create a second pipeline that reuses all components from the first one. We pass the same Components Manager to the second pipeline but with a different collection:
```py
# Load all default components
>>> pipe.load_default_components()
# Create a second pipeline using the same Components Manager but with a different collection
>>> pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
```
As mentioned earlier, `ModularPipeline` has a property `null_component_names` that returns a list of component names it needs to load. We can conveniently use this list with the `get_components_by_names` method on the Components Manager:
```py
# Get the list of components that pipe2 needs to load
>>> pipe2.null_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
# Retrieve all required components from the Components Manager
>>> comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
# Update the pipeline with the retrieved components
>>> pipe2.update_components(**comp_dict)
```
The warnings that follow are expected and indicate that the Components Manager is correctly identifying that these components already exist and will be reused rather than creating duplicates:
```out
ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917586016400'
ComponentsManager: component 'text_encoder_2' already exists as 'text_encoder_2_139917699973424'
ComponentsManager: component 'tokenizer' already exists as 'tokenizer_139917580599504'
ComponentsManager: component 'tokenizer_2' already exists as 'tokenizer_2_139915763443904'
ComponentsManager: component 'image_encoder' already exists as 'image_encoder_139917722468304'
ComponentsManager: component 'unet' already exists as 'unet_139917580609632'
ComponentsManager: component 'vae' already exists as 'vae_139917722459040'
ComponentsManager: component 'scheduler' already exists as 'scheduler_139916266559408'
ComponentsManager: component 'controlnet' already exists as 'controlnet_139917722454432'
```
The pipeline is now fully loaded:
```py
# null_component_names return empty list, meaning everything are loaded
>>> pipe2.null_component_names
[]
```
No new components were added to the Components Manager - we're reusing everything. All models are now associated with both `test1` and `test2` collections, showing that these components are shared across multiple pipelines:
```py
>>> comp
Components:
========================================================================================================================================================================================
Models:
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
text_encoder_139917586016400 | CLIPTextModel | cpu | torch.float32 | 0.46 | SG161222/RealVisXL_V4.0|text_encoder|null|null | test1
| | | | | | test2
text_encoder_2_139917699973424 | CLIPTextModelWithProjection | cpu | torch.float32 | 2.59 | SG161222/RealVisXL_V4.0|text_encoder_2|null|null | test1
| | | | | | test2
unet_139917580609632 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1
| | | | | | test2
controlnet_139917722454432 | ControlNetModel | cpu | torch.float32 | 4.66 | diffusers/controlnet-canny-sdxl-1.0|null|null|null | test1
| | | | | | test2
vae_139917722459040 | AutoencoderKL | cpu | torch.float32 | 0.31 | SG161222/RealVisXL_V4.0|vae|null|null | test1
| | | | | | test2
image_encoder_139917722468304 | CLIPVisionModelWithProjection | cpu | torch.float32 | 6.87 | h94/IP-Adapter|sdxl_models/image_encoder|null|null | test1
| | | | | | test2
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Other Components:
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
ID | Class | Collection
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
tokenizer_139917580599504 | CLIPTokenizer | test1
| | test2
scheduler_139916266559408 | EulerDiscreteScheduler | test1
| | test2
tokenizer_2_139915763443904 | CLIPTokenizer | test1
| | test2
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
```
## Automatic Memory Management
The Components Manager provides a global offloading strategy across all models, regardless of which pipeline is using them:
```py
comp.enable_auto_cpu_offload(device="cuda")
```
When enabled, all models start on CPU. The manager moves models to the device right before they're used and moves other models back to CPU when GPU memory runs low. You can set your own rules for which models to offload first. This works smoothly as you add or remove components. Once it's on, you don't need to worry about device placement - you can focus on your workflow.
## Practical Example: Building Modular Workflows with Component Reuse
Now that we've covered the basics of the Components Manager, let's walk through a practical example that shows how to build workflows in a modular setting and use the Components Manager to reuse components across multiple pipelines. This example demonstrates the true power of Modular Diffusers by working with multiple pipelines that can share components.
In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline.
Let's create a modular text-to-image workflow by separating it into three workflows: `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images.
```py
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
# Create modular blocks and separate text encoding and decoding steps
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2img"])
text_blocks = t2i_blocks.sub_blocks.pop("text_encoder")
decoder_blocks = t2i_blocks.sub_blocks.pop("decode")
```
Now we will convert them into runnalbe pipelines and set up the Components Manager with auto offloading and organize components under a "t2i" collection
Since we now have 3 different workflows that share components, we create a separate pipeline that serves as a dedicated loader to load all the components, register them to the component manager, and then reuse them across different workflows.
```py
from diffusers import ComponentsManager, ModularPipeline
# Set up Components Manager with auto offloading
components = ComponentsManager()
components.enable_auto_cpu_offload(device="cuda")
# Create a new pipeline to load the components
t2i_repo = "YiYiXu/modular-demo-auto"
t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i")
# convert the 3 blocks into pipelines and attach the same components manager to all 3
text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components)
decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components)
t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components)
```
Load all components into the loader pipeline, they should all be automatically registered to Components Manager under the "t2i" collection:
```py
# Load all components (including IP-Adapter and ControlNet for later use)
t2i_loader_pipe.load_default_components(torch_dtype=torch.float16)
```
Now distribute the loaded components to each pipeline:
```py
# Get VAE for decoder (using get_one since there's only one)
vae = components.get_one(load_id="SG161222/RealVisXL_V4.0|vae|null|null")
decoder_node.update_components(vae=vae)
# Get text components for text node (using get_components_by_names for multiple components)
text_components = components.get_components_by_names(text_node.null_component_names)
text_node.update_components(**text_components)
# Get remaining components for t2i pipeline
t2i_components = components.get_components_by_names(t2i_pipe.null_component_names)
t2i_pipe.update_components(**t2i_components)
```
Now we can generate images using our modular workflow:
```py
# Generate text embeddings
prompt = "an astronaut"
text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"])
# Generate latents and decode to image
generator = torch.Generator(device="cuda").manual_seed(0)
latents_t2i = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
image = decoder_node(latents=latents_t2i, output="images")[0]
image.save("modular_part2_t2i.png")
```
Let's add a LoRA:
```py
# Load LoRA weights
>>> t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face")
>>> components
Components:
============================================================================================================================================================
...
Additional Component Info:
==================================================
unet:
Adapters: ['toy_face']
```
You can see that the Components Manager tracks adapters metadata for all models it manages, and in our case, only Unet has lora loaded. This means we can reuse existing text embeddings.
```py
# Generate with LoRA (reusing existing text embeddings)
generator = torch.Generator(device="cuda").manual_seed(0)
latents_lora = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
image = decoder_node(latents=latents_lora, output="images")[0]
image.save("modular_part2_lora.png")
```
Now let's create a refiner pipeline that reuses components from our text-to-image workflow:
```py
# Create refiner blocks (removing image_encoder and decode since we work with latents)
refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"])
refiner_blocks.sub_blocks.pop("image_encoder")
refiner_blocks.sub_blocks.pop("decode")
# Create refiner pipeline with different repo and collection,
# Attach the same component manager to it
refiner_repo = "YiYiXu/modular_refiner"
refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner")
```
We pass the **same Components Manager** (`components`) to the refiner pipeline, but with a **different collection** (`"refiner"`). This allows the refiner to access and reuse components from the "t2i" collection while organizing its own components (like the refiner UNet) under the "refiner" collection.
```py
# Load only the refiner UNet (different from t2i UNet)
refiner_pipe.load_components(names="unet", torch_dtype=torch.float16)
# Reuse components from t2i pipeline using pattern matching
reuse_components = components.search_components("text_encoder_2|scheduler|vae|tokenizer_2")
refiner_pipe.update_components(**reuse_components)
```
When we reuse components from the "t2i" collection, they automatically get added to the "refiner" collection as well. You can verify this by checking the Components Manager - you'll see components like `vae`, `scheduler`, etc. listed under both collections, indicating they're shared between workflows.
Now we can refine any of our generated latents:
```py
# Refine all our different latents
refined_latents = refiner_pipe(image_latents=latents_t2i, prompt=prompt, num_inference_steps=10, output="latents")
refined_image = decoder_node(latents=refined_latents, output="images")[0]
refined_image.save("modular_part2_t2i_refine_out.png")
refined_latents = refiner_pipe(image_latents=latents_lora, prompt=prompt, num_inference_steps=10, output="latents")
refined_image = decoder_node(latents=refined_latents, output="images")[0]
refined_image.save("modular_part2_lora_refine_out.png")
```
Here are the results from our modular pipeline examples.
#### Base Text-to-Image Generation
| Base Text-to-Image | Base Text-to-Image (Refined) |
|-------------------|------------------------------|
| ![Base T2I](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i.png) | ![Base T2I Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i_refine_out.png) |
#### LoRA
| LoRA | LoRA (Refined) |
|-------------------|------------------------------|
| ![LoRA](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora.png) | ![LoRA Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora_refine_out.png) |
@@ -1,648 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# End-to-End Developer Guide: Building with Modular Diffusers
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
In this tutorial we will walk through the process of adding a new pipeline to the modular framework using differential diffusion as our example. We'll cover the complete workflow from implementation to deployment: implementing the new pipeline, ensuring compatibility with existing tools, sharing the code on Hugging Face Hub, and deploying it as a UI node.
We'll also demonstrate the 4-step framework process we use for implementing new basic pipelines in the modular system.
1. **Start with an existing pipeline as a base**
- Identify which existing pipeline is most similar to the one you want to implement
- Determine what part of the pipeline needs modification
2. **Build a working pipeline structure first**
- Assemble the complete pipeline structure
- Use existing blocks wherever possible
- For new blocks, create placeholders (e.g. you can copy from similar blocks and change the name) without implementing custom logic just yet
3. **Set up an example**
- Create a simple inference script with expected inputs/outputs
4. **Implement your custom logic and test incrementally**
- Add the custom logics the blocks you want to change
- Test incrementally, and inspect pipeline states and debug as needed
Let's see how this works with the Differential Diffusion example.
## Differential Diffusion Pipeline
### Start with an existing pipeline
Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them.
```py
>>> from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
>>> IMAGE2IMAGE_BLOCKS = InsertableDict([
... ("text_encoder", StableDiffusionXLTextEncoderStep),
... ("image_encoder", StableDiffusionXLVaeEncoderStep),
... ("input", StableDiffusionXLInputStep),
... ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
... ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
... ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
... ("denoise", StableDiffusionXLDenoiseStep),
... ("decode", StableDiffusionXLDecodeStep)
... ])
```
Note that "denoise" (`StableDiffusionXLDenoiseStep`) is a `LoopSequentialPipelineBlocks` that contains 3 loop blocks (more on LoopSequentialPipelineBlocks [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#loopsequentialpipelineblocks))
```py
>>> denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
>>> print(denoise_blocks)
```
```out
StableDiffusionXLDenoiseStep(
Class: StableDiffusionXLDenoiseLoopWrapper
Description: Denoise step that iteratively denoise the latents.
Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
At each iteration, it runs blocks defined in `sub_blocks` sequencially:
- `StableDiffusionXLLoopBeforeDenoiser`
- `StableDiffusionXLLoopDenoiser`
- `StableDiffusionXLLoopAfterDenoiser`
This block supports both text2img and img2img tasks.
Components:
scheduler (`EulerDiscreteScheduler`)
guider (`ClassifierFreeGuidance`)
unet (`UNet2DConditionModel`)
Sub-Blocks:
[0] before_denoiser (StableDiffusionXLLoopBeforeDenoiser)
Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
[1] denoiser (StableDiffusionXLLoopDenoiser)
Description: Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
[2] after_denoiser (StableDiffusionXLLoopAfterDenoiser)
Description: step within the denoising loop that update the latents. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
)
```
Let's compare standard image-to-image and differential diffusion! The key difference in algorithm is that standard image-to-image diffusion applies uniform noise across all pixels based on a single `strength` parameter, but differential diffusion uses a change map where each pixel value determines when that region starts denoising. Regions with lower values get "frozen" earlier by replacing them with noised original latents, preserving more of the original image.
Therefore, the key differences when it comes to pipeline implementation would be:
1. The `prepare_latents` step (which prepares the change map and pre-computes noised latents for all timesteps)
2. The `denoise` step (which selectively applies denoising based on the change map)
3. Since differential diffusion doesn't use the `strength` parameter, we'll use the text-to-image `set_timesteps` step instead of the image-to-image version
To implement differntial diffusion, we can reuse most blocks from image-to-image and text-to-image workflows, only modifying the `prepare_latents` step and the first part of the `denoise` step (i.e. `before_denoiser (StableDiffusionXLLoopBeforeDenoiser)`).
Here's a flowchart showing the pipeline structure and the changes we need to make:
![DiffDiff Pipeline Structure](https://mermaid.ink/img/pako:eNqVVO9r4kAQ_VeWLQWFKEk00eRDwZpa7Q-ucPfpYpE1mdWlcTdsVmpb-7_fZk1tTCl3J0Sy8968N5kZ9g0nIgUc4pUk-Rr9iuYc6d_Ibs14vlXoQYpNrtqo07lAo1jBTi2AlynysWIa6DJmG7KCBnZpsHHMSqkqNjaxKC5ALRTbQKEgLyosMthVnEvIiYRFRhRwVaBoNpmUT0W7MrTJkUbSdJEInlbwxMDXcQpcsAKq6OH_2mDTODIY4yt0J0ReUaYGnLXiJVChdSsB-enfPhBnhnjT-rCQj-1K_8Ygt62YUAVy8Ykf4FvU6XYu9rpuIGqPpvXSzs_RVEj2KrgiGUp02zNQTHBEM_FcK3BfQbBHd7qAst-PxvW-9WOrypnNylG0G9oRUMYBFeolg-IQTTJSFDqOUkZp-fwsQURZloVnlPpLf2kVSoonCM-SwCUuqY6dZ5aqddjLd1YiMiFLNrWorrxj9EOmP4El37lsl_9p5PzFqIqwVwgdN981fDM94bphH5I06R8NXZ_4QcPQPTFs6JltPrS6JssFhw9N817l27bdyM-lSKAo6iVBAAnQY0n9wLO9wbcluY7ruUFDtdguH74K0yENKDkK-8nAG6TfNrfy_bf-HjdrlOfZS7VYSAlU5JAwyhLE9WrWVw1dWdPTXauDsy8LUkdHtnX_pfMnBOvSGluRNbGurbuTHtdZN9Zts1MljC19_7EUh0puwcIbkBtSHvFbic6xWsMG5jjUrymRT3M85-86Jyf8txCbjzQptqs1DinJCn3a5qm-viJG9M26OUYlcH0_jsWWKxwGttHA4Rve4dD1el3H8_yh49hD3_X7roVfcNhx-l3b14PxvGHQ0xMa9t4t_Gp8na7tDvu-4w08HXecweD9D4X54ZI)
### Build a Working Pipeline Structure
ok now we've identified the blocks to modify, let's build the pipeline skeleton first - at this stage, our goal is to get the pipeline struture working end-to-end (even though it's just doing the img2img behavior). I would simply create placeholder blocks by copying from existing ones:
```py
>>> # Copy existing blocks as placeholders
>>> class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
... """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
... # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
...
>>> class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
... """Copied from StableDiffusionXLLoopBeforeDenoiser - will modify later"""
... # ... same implementation as StableDiffusionXLLoopBeforeDenoiser
```
`SDXLDiffDiffLoopBeforeDenoiser` is the be part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseStep`.
```py
>>> class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
... block_names = ["before_denoiser", "denoiser", "after_denoiser"]
```
Now we can put together our differential diffusion pipeline.
```py
>>> DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
>>> DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
>>> DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
>>> DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
>>>
>>> dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
>>> print(dd_blocks)
>>> # At this point, the pipeline works exactly like img2img since our blocks are just copies
```
### Set up an example
ok, so now our blocks should be able to compile without an error, we can move on to the next step. Let's setup a simple example so we can run the pipeline as we build it. diff-diff use same model checkpoints as SDXL so we can fetch the models from a regular SDXL repo.
```py
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
>>> dd_pipeline.to("cuda")
```
We will use this example script:
```py
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
>>>
>>> prompt = "a green pear"
>>> negative_prompt = "blurry"
>>>
>>> image = dd_pipeline(
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_inference_steps=25,
... diffdiff_map=mask,
... image=image,
... output="images"
... )[0]
>>>
>>> image.save("diffdiff_out.png")
```
If you run the script right now, you will get a complaint about unexpected input `diffdiff_map`.
and you would get the same result as the original img2img pipeline.
### implement your custom logic and test incrementally
Let's modify the pipeline so that we can get expected result with this example script.
We'll start with the `prepare_latents` step. The main changes are:
- Requires a new user input `diffdiff_map`
- Requires new component `mask_processor` to process the `diffdiff_map`
- Requires new intermediate inputs:
- Need `timestep` instead of `latent_timestep` to precompute all the latents
- Need `num_inference_steps` to create the `diffdiff_masks`
- create a new output `diffdiff_masks` and `original_latents`
<Tip>
💡 use `print(dd_pipeline.doc)` to check compiled inputs and outputs of the built piepline.
e.g. after we added `diffdiff_map` as an input in this step, we can run `print(dd_pipeline.doc)` to verify that it shows up in the docstring as a user input.
</Tip>
Once we make sure all the variables we need are available in the block state, we can implement the diff-diff logic inside `__call__`. We created 2 new variables: the change map `diffdiff_mask` and the pre-computed noised latents for all timesteps `original_latents`.
<Tip>
💡 Implement incrementally! Run the example script as you go, and insert `print(state)` and `print(block_state)` everywhere inside the `__call__` method to inspect the intermediate results. This helps you understand what's going on and what each line you just added does.
</Tip>
Here are the key changes we made to implement differential diffusion:
**1. Modified `prepare_latents` step:**
```diff
class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
]
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
+ InputParam("diffdiff_map", required=True),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
- InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
+ InputParam("timesteps", type_hint=torch.Tensor),
+ InputParam("num_inference_steps", type_hint=int),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
+ OutputParam("original_latents", type_hint=torch.Tensor),
+ OutputParam("diffdiff_masks", type_hint=torch.Tensor),
]
def __call__(self, components, state: PipelineState):
# ... existing logic ...
+ # Process change map and create masks
+ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
+ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
+ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
+ block_state.original_latents = block_state.latents
```
**2. Modified `before_denoiser` step:**
```diff
class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
@property
def description(self) -> str:
return (
"Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
)
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("denoising_start"),
+ ]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor),
+ InputParam("original_latents", type_hint=torch.Tensor),
+ InputParam("diffdiff_masks", type_hint=torch.Tensor),
]
def __call__(self, components, block_state, i, t):
+ # Apply differential diffusion logic
+ if i == 0 and block_state.denoising_start is None:
+ block_state.latents = block_state.original_latents[:1]
+ else:
+ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
+ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
# ... rest of existing logic ...
```
That's all there is to it! We've just created a simple sequential pipeline by mix-and-match some existing and new pipeline blocks.
Now we use the process we've prepred in step2 to build the pipeline and inspect it.
```py
>> dd_pipeline
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
Description:
Components:
text_encoder (`CLIPTextModel`)
text_encoder_2 (`CLIPTextModelWithProjection`)
tokenizer (`CLIPTokenizer`)
tokenizer_2 (`CLIPTokenizer`)
guider (`ClassifierFreeGuidance`)
vae (`AutoencoderKL`)
image_processor (`VaeImageProcessor`)
scheduler (`EulerDiscreteScheduler`)
mask_processor (`VaeImageProcessor`)
unet (`UNet2DConditionModel`)
Configs:
force_zeros_for_empty_prompt (default: True)
requires_aesthetics_score (default: False)
Blocks:
[0] text_encoder (StableDiffusionXLTextEncoderStep)
Description: Text Encoder step that generate text_embeddings to guide the image generation
[1] image_encoder (StableDiffusionXLVaeEncoderStep)
Description: Vae Encoder step that encode the input image into a latent representation
[2] input (StableDiffusionXLInputStep)
Description: Input processing step that:
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
All input tensors are expected to have either batch_size=1 or match the batch_size
of prompt_embeds. The tensors will be duplicated across the batch dimension to
have a final batch_size of batch_size * num_images_per_prompt.
[3] set_timesteps (StableDiffusionXLSetTimestepsStep)
Description: Step that sets the scheduler's timesteps for inference
[4] prepare_latents (SDXLDiffDiffPrepareLatentsStep)
Description: Step that prepares the latents for the differential diffusion generation process
[5] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process
[6] denoise (SDXLDiffDiffDenoiseStep)
Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes
[7] decode (StableDiffusionXLDecodeStep)
Description: Step that decodes the denoised latents into images
)
```
Run the example now, you should see an apple with its right half transformed into a green pear.
![Image description](https://cdn-uploads.huggingface.co/production/uploads/624ef9ba9d608e459387b34e/4zqJOz-35Q0i6jyUW3liL.png)
## Adding IP-adapter
We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](./auto_pipeline_blocks.md)
We talked about how to add IP-adapter into your workflow in the [Modular Pipeline Guide](./modular_pipeline.md). Let's just go ahead to create the IP-adapter block.
```py
>>> from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
```
We can directly add the ip-adapter block instance to the `diffdiff_blocks` that we created before. The `sub_blocks` attribute is a `InsertableDict`, so we're able to insert the it at specific position (index `0` here).
```py
>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
```
Take a look at the new diff-diff pipeline with ip-adapter!
```py
>>> print(dd_blocks)
```
The pipeline now lists ip-adapter as its first block, and tells you that it will run only if `ip_adapter_image` is provided. It also includes the two new components from ip-adpater: `image_encoder` and `feature_extractor`
```out
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
====================================================================================================
This pipeline contains blocks that are selected at runtime based on inputs.
Trigger Inputs: {'ip_adapter_image'}
Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`).
====================================================================================================
Description:
Components:
image_encoder (`CLIPVisionModelWithProjection`)
feature_extractor (`CLIPImageProcessor`)
unet (`UNet2DConditionModel`)
guider (`ClassifierFreeGuidance`)
text_encoder (`CLIPTextModel`)
text_encoder_2 (`CLIPTextModelWithProjection`)
tokenizer (`CLIPTokenizer`)
tokenizer_2 (`CLIPTokenizer`)
vae (`AutoencoderKL`)
image_processor (`VaeImageProcessor`)
scheduler (`EulerDiscreteScheduler`)
mask_processor (`VaeImageProcessor`)
Configs:
force_zeros_for_empty_prompt (default: True)
requires_aesthetics_score (default: False)
Blocks:
[0] ip_adapter (StableDiffusionXLAutoIPAdapterStep)
Description: Run IP Adapter step if `ip_adapter_image` is provided.
[1] text_encoder (StableDiffusionXLTextEncoderStep)
Description: Text Encoder step that generate text_embeddings to guide the image generation
[2] image_encoder (StableDiffusionXLVaeEncoderStep)
Description: Vae Encoder step that encode the input image into a latent representation
[3] input (StableDiffusionXLInputStep)
Description: Input processing step that:
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
All input tensors are expected to have either batch_size=1 or match the batch_size
of prompt_embeds. The tensors will be duplicated across the batch dimension to
have a final batch_size of batch_size * num_images_per_prompt.
[4] set_timesteps (StableDiffusionXLSetTimestepsStep)
Description: Step that sets the scheduler's timesteps for inference
[5] prepare_latents (SDXLDiffDiffPrepareLatentsStep)
Description: Step that prepares the latents for the differential diffusion generation process
[6] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process
[7] denoise (SDXLDiffDiffDenoiseStep)
Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes
[8] decode (StableDiffusionXLDecodeStep)
Description: Step that decodes the denoised latents into images
)
```
Let's test it out. We used an orange image to condition the generation via ip-addapter and we can see a slight orange color and texture in the final output.
```py
>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
>>>
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
>>> dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
>>> dd_pipeline.loader.set_ip_adapter_scale(0.6)
>>> dd_pipeline = dd_pipeline.to(device)
>>>
>>> ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
>>>
>>> prompt = "a green pear"
>>> negative_prompt = "blurry"
>>> generator = torch.Generator(device=device).manual_seed(42)
>>>
>>> image = dd_pipeline(
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_inference_steps=25,
... generator=generator,
... ip_adapter_image=ip_adapter_image,
... diffdiff_map=mask,
... image=image,
... output="images"
... )[0]
```
## Working with ControlNets
What about controlnet? Can differential diffusion work with controlnet? The key differences between a regular pipeline and a ControlNet pipeline are:
1. A ControlNet input step that prepares the control condition
2. Inside the denoising loop, a modified denoiser step where the control image is first processed through ControlNet, then control information is injected into the UNet
From looking at the code workflow: differential diffusion only modifies the "before denoiser" step, while ControlNet operates within the "denoiser" itself. Since they intervene at different points in the pipeline, they should work together without conflicts.
Intuitively, these two techniques are orthogonal and should combine naturally: differential diffusion controls how much the inference process can deviate from the original in each region, while ControlNet controls in what direction that change occurs.
With this understanding, let's assemble the diffdiff-controlnet loop by combining the diffdiff before-denoiser step and controlnet denoiser step.
```py
>>> class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
... block_names = ["before_denoiser", "denoiser", "after_denoiser"]
>>>
>>> controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
>>> # print(controlnet_denoise)
```
We provide a auto controlnet input block that you can directly put into your workflow to proceess the `control_image`: similar to auto ip-adapter block, this step will only run if `control_image` input is passed from user. It work with both controlnet and controlnet union.
```py
>>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
>>> control_input_block = StableDiffusionXLAutoControlNetInputStep()
>>> print(control_input_block)
```
```out
StableDiffusionXLAutoControlNetInputStep(
Class: AutoPipelineBlocks
====================================================================================================
This pipeline contains blocks that are selected at runtime based on inputs.
Trigger Inputs: ['control_image', 'control_mode']
====================================================================================================
Description: Controlnet Input step that prepare the controlnet input.
This is an auto pipeline block that works for both controlnet and controlnet_union.
(it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.
- `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped.
Components:
controlnet (`ControlNetUnionModel`)
control_image_processor (`VaeImageProcessor`)
Sub-Blocks:
• controlnet_union [trigger: control_mode] (StableDiffusionXLControlNetUnionInputStep)
Description: step that prepares inputs for the ControlNetUnion model
• controlnet [trigger: control_image] (StableDiffusionXLControlNetInputStep)
Description: step that prepare inputs for controlnet
)
```
Let's assemble the blocks and run an example using controlnet + differential diffusion. We used a tomato as `control_image`, so you can see that in the output, the right half that transformed into a pear had a tomato-like shape.
```py
>>> dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
>>> dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
>>>
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
>>> dd_pipeline = dd_pipeline.to(device)
>>>
>>> control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
>>>
>>> prompt = "a green pear"
>>> negative_prompt = "blurry"
>>> generator = torch.Generator(device=device).manual_seed(42)
>>>
>>> image = dd_pipeline(
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_inference_steps=25,
... generator=generator,
... control_image=control_image,
... controlnet_conditioning_scale=0.5,
... diffdiff_map=mask,
... image=image,
... output="images"
... )[0]
```
Optionally, We can combine `SDXLDiffDiffControlNetDenoiseStep` and `SDXLDiffDiffDenoiseStep` into a `AutoPipelineBlocks` so that same workflow can work with or without controlnet.
```py
>>> class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
... block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
... block_names = ["controlnet_denoise", "denoise"]
... block_trigger_inputs = ["controlnet_cond", None]
```
`SDXLDiffDiffAutoDenoiseStep` will run the ControlNet denoise step if `control_image` input is provided, otherwise it will run the regular denoise step.
<Tip>
Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end once you've verified all your pipelines work as expected.
</Tip>
Now you can create the differential diffusion preset that works with ip-adapter & controlnet.
```py
>>> DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
>>> DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
>>> DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
>>> DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
>>> DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
>>> DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7)
>>>
>>> print(DIFFDIFF_AUTO_BLOCKS)
```
to use
```py
>>> dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
>>> dd_pipeline = dd_auto_blocks.init_pipeline(...)
```
## Creating a Modular Repo
You can easily share your differential diffusion workflow on the Hub by creating a modular repo. This is one created using the code we just wrote together: https://huggingface.co/YiYiXu/modular-diffdiff
To create a Modular Repo and share on hub, you just need to run `save_pretrained()` along with the `push_to_hub=True` flag. Note that if your pipeline contains custom block, you need to manually upload the code to the hub. But we are working on a command line tool to help you upload it very easily.
```py
dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
```
With a modular repo, it is very easy for the community to use the workflow you just created! Here is an example to use the differential-diffusion pipeline we just created and shared.
```py
>>> from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
>>> import torch
>>> from diffusers.utils import load_image
>>>
>>> repo_id = "YiYiXu/modular-diffdiff-0704"
>>>
>>> components = ComponentsManager()
>>>
>>> diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, components_manager=components, collection="diffdiff")
>>> diffdiff_pipeline.load_default_components(torch_dtype=torch.float16)
>>> components.enable_auto_cpu_offload()
```
see more usage example on model card.
## deploy a mellon node
[YIYI TODO: for now, here is an example of mellon node https://huggingface.co/YiYiXu/diff-diff-mellon]
@@ -1,194 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LoopSequentialPipelineBlocks
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
`LoopSequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that composes other blocks together in a loop, creating iterative workflows where blocks run multiple times with evolving state. It's particularly useful for denoising loops requiring repeated execution of the same blocks.
<Tip>
Other types of multi-blocks include [SequentialPipelineBlocks](./sequential_pipeline_blocks.md) (for linear workflows) and [AutoPipelineBlocks](./auto_pipeline_blocks.md) (for conditional block selection). For information on creating individual blocks, see the [PipelineBlock guide](./pipeline_block.md).
Additionally, like all `ModularPipelineBlocks`, `LoopSequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
</Tip>
You could create a loop using `PipelineBlock` like this:
```python
class DenoiseLoop(PipelineBlock):
def __call__(self, components, state):
block_state = self.get_block_state(state)
for t in range(block_state.num_inference_steps):
# ... loop logic here
pass
self.set_block_state(state, block_state)
return components, state
```
But in this tutorial, we will focus on how to use `LoopSequentialPipelineBlocks` to create a "composable" denoising loop where you can add or remove blocks within the loop or reuse the same loop structure with different block combinations.
It involves two parts: a **loop wrapper** and **loop blocks**
* The **loop wrapper** (`LoopSequentialPipelineBlocks`) defines the loop structure, e.g. it defines the iteration variables, and loop configurations such as progress bar.
* The **loop blocks** are basically standard pipeline blocks you add to the loop wrapper.
- they run sequentially for each iteration of the loop
- they receive the current iteration index as an additional parameter
- they share the same block_state throughout the entire loop
Unlike regular `SequentialPipelineBlocks` where each block gets its own state, loop blocks share a single state that persists and evolves across iterations.
We will build a simple loop block to demonstrate these concepts. Creating a loop block involves three steps:
1. defining the loop wrapper class
2. creating the loop blocks
3. adding the loop blocks to the loop wrapper class to create the loop wrapper instance
**Step 1: Define the Loop Wrapper**
To create a `LoopSequentialPipelineBlocks` class, you need to define:
* `loop_inputs`: User input variables (equivalent to `PipelineBlock.inputs`)
* `loop_intermediate_inputs`: Intermediate variables needed from the mutable pipeline state (equivalent to `PipelineBlock.intermediates_inputs`)
* `loop_intermediate_outputs`: New intermediate variables this block will add to the mutable pipeline state (equivalent to `PipelineBlock.intermediates_outputs`)
* `__call__` method: Defines the loop structure and iteration logic
Here is an example of a loop wrapper:
```py
import torch
from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, PipelineBlock, InputParam, OutputParam
class LoopWrapper(LoopSequentialPipelineBlocks):
model_name = "test"
@property
def description(self):
return "I'm a loop!!"
@property
def loop_inputs(self):
return [InputParam(name="num_steps")]
@torch.no_grad()
def __call__(self, components, state):
block_state = self.get_block_state(state)
# Loop structure - can be customized to your needs
for i in range(block_state.num_steps):
# loop_step executes all registered blocks in sequence
components, block_state = self.loop_step(components, block_state, i=i)
self.set_block_state(state, block_state)
return components, state
```
**Step 2: Create Loop Blocks**
Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently:
* It receives the iteration variable (e.g., `i`) passed by the loop wrapper
* It works directly with `block_state` instead of pipeline state
* No need to call `self.get_block_state()` or `self.set_block_state()`
```py
class LoopBlock(PipelineBlock):
# this is used to identify the model family, we won't worry about it in this example
model_name = "test"
@property
def inputs(self):
return [InputParam(name="x")]
@property
def intermediate_outputs(self):
# outputs produced by this block
return [OutputParam(name="x")]
@property
def description(self):
return "I'm a block used inside the `LoopWrapper` class"
def __call__(self, components, block_state, i: int):
block_state.x += 1
return components, block_state
```
**Step 3: Combine Everything**
Finally, assemble your loop by adding the block(s) to the wrapper:
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock})
```
Now you've created a loop with one step:
```py
>>> loop
LoopWrapper(
Class: LoopSequentialPipelineBlocks
Description: I'm a loop!!
Sub-Blocks:
[0] block1 (LoopBlock)
Description: I'm a block used inside the `LoopWrapper` class
)
```
It has two inputs: `x` (used at each step within the loop) and `num_steps` used to define the loop.
```py
>>> print(loop.doc)
class LoopWrapper
I'm a loop!!
Inputs:
x (`None`, *optional*):
num_steps (`None`, *optional*):
Outputs:
x (`None`):
```
**Running the Loop:**
```py
# run the loop
loop_pipeline = loop.init_pipeline()
x = loop_pipeline(num_steps=10, x=0, output="x")
assert x == 10
```
**Adding Multiple Blocks:**
We can add multiple blocks to run within each iteration. Let's run the loop block twice within each iteration:
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
loop_pipeline = loop.init_pipeline()
x = loop_pipeline(num_steps=10, x=0, output="x")
assert x == 20 # Each iteration runs 2 blocks, so 10 iterations * 2 = 20
```
**Key Differences from SequentialPipelineBlocks:**
The main difference is that loop blocks share the same `block_state` across all iterations, allowing values to accumulate and evolve throughout the loop. Loop blocks could receive additional arguments (like the current iteration index) depending on the loop wrapper's implementation, since the wrapper defines how loop blocks are called. You can easily add, remove, or reorder blocks within the loop without changing the loop logic itself.
The officially supported denoising loops in Modular Diffusers are implemented using `LoopSequentialPipelineBlocks`. You can explore the actual implementation to see how these concepts work in practice:
```py
from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLDenoiseStep
StableDiffusionXLDenoiseStep()
```
@@ -1,59 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# PipelineState and BlockState
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
In Modular Diffusers, `PipelineState` and `BlockState` are the core data structures that enable blocks to communicate and share data. The concept is fundamental to understand how blocks interact with each other and the pipeline system.
In the modular diffusers system, `PipelineState` acts as the global state container that all pipeline blocks operate on. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data.
A `PipelineState` consists of two distinct states:
- **The immutable state** (i.e. the `inputs` dict) contains a copy of values provided by users. Once a value is added to the immutable state, it cannot be changed. Blocks can read from the immutable state but cannot write to it.
- **The mutable state** (i.e. the `intermediates` dict) contains variables that are passed between blocks and can be modified by them.
Here's an example of what a `PipelineState` looks like:
```py
PipelineState(
inputs={
'prompt': 'a cat'
'guidance_scale': 7.0
'num_inference_steps': 25
},
intermediates={
'prompt_embeds': Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1]))
'negative_prompt_embeds': None
},
)
```
Each pipeline blocks define what parts of that state they can read from and write to through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties. At run time, they gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes.
For example, if a block defines an input `image`, inside the block's `__call__` method, the `BlockState` would contain:
```py
BlockState(
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494640>
)
```
You can access the variables directly as attributes: `block_state.image`.
We will explore more on how blocks interact with pipeline state through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties, see the [PipelineBlock guide](./pipeline_block.md).
File diff suppressed because it is too large Load Diff
@@ -1,42 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Getting Started with Modular Diffusers
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you:
**Write Only What's New**: You won't need to write an entire pipeline from scratch every time you have a new use case. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities.
**Assemble Like LEGO®**: You can mix and match between blocks in flexible ways. This allows you to write dedicated blocks unique to specific workflows, and then assemble different blocks into a pipeline that can be used more conveniently for multiple workflows.
Here's how our guides are organized to help you navigate the Modular Diffusers documentation:
### 🚀 Running Pipelines
- **[Modular Pipeline Guide](./modular_pipeline.md)** - How to use predefined blocks to build a pipeline and run it
- **[Components Manager Guide](./components_manager.md)** - How to manage and reuse components across multiple pipelines
### 📚 Creating PipelineBlocks
- **[Pipeline and Block States](./modular_diffusers_states.md)** - Understanding PipelineState and BlockState
- **[Pipeline Block](./pipeline_block.md)** - How to write custom PipelineBlocks
- **[SequentialPipelineBlocks](sequential_pipeline_blocks.md)** - Connecting blocks in sequence
- **[LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks.md)** - Creating iterative workflows
- **[AutoPipelineBlocks](./auto_pipeline_blocks.md)** - Conditional block selection
### 🎯 Practical Examples
- **[End-to-End Example](./end_to_end_guide.md)** - Complete end-to-end examples including sharing your workflow in huggingface hub and deplying UI nodes
@@ -1,292 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# PipelineBlock
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows.
<Tip>
**Important**: `PipelineBlock`s are definitions/specifications, not runnable pipelines. They define what a block should do and what data it needs, but you need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](./modular_pipeline.md).
</Tip>
In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with the pipeline state.
## PipelineState
Before we dive into creating `PipelineBlock`s, make sure you have a basic understanding of `PipelineState`. It acts as the global state container that all blocks operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. See the [PipelineState and BlockState guide](./modular_diffusers_states.md) for more details.
## Define a `PipelineBlock`
To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce.
The three main properties you need to define are:
- `inputs`: Immutable values from the user that cannot be modified
- `intermediate_inputs`: Mutable values from previous blocks that can be read and modified
- `intermediate_outputs`: New values your block creates for subsequent blocks and user access
Let's explore each one and understand how they work with the pipeline state.
**Inputs: Immutable User Values**
Inputs are variables your block needs from the immutable pipeline state - these are user-provided values that cannot be modified by any block. You define them using `InputParam`:
```py
user_inputs = [
InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
]
```
When you list something as an input, you're saying "I need this value directly from the end user, and I will talk to them directly, telling them what I need in the 'description' field. They will provide it and it will come to me unchanged."
This is especially useful for raw values that serve as the "source of truth" in your workflow. For example, with a raw image, many workflows require preprocessing steps like resizing that a previous block might have performed. But in many cases, you also want the raw PIL image. In some inpainting workflows, you need the original image to overlay with the generated result for better control and consistency.
**Intermediate Inputs: Mutable Values from Previous Blocks, or Users**
Intermediate inputs are variables your block needs from the mutable pipeline state - these are values that can be read and modified. They're typically created by previous blocks, but could also be directly provided by the user if not the case:
```py
user_intermediate_inputs = [
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
]
```
When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers want use something different."
**Intermediate Outputs: New Values for Subsequent Blocks and User Access**
Intermediate outputs are new variables your block creates and adds to the mutable pipeline state. They serve two purposes:
1. **For subsequent blocks**: They can be used as intermediate inputs by other blocks in the pipeline
2. **For users**: They become available as final outputs that users can access when running the pipeline
```py
user_intermediate_outputs = [
OutputParam(name="image_latents", description="latents representing the image")
]
```
Intermediate inputs and intermediate outputs work together like Lego studs and anti-studs - they're the connection points that make blocks modular. When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. This is where the "modular" nature of the system really shines - blocks can be connected and reconnected in different ways as long as their inputs and outputs match.
Additionally, all intermediate outputs are accessible to users when they run the pipeline, typically you would only need the final images, but they are also able to access intermediate results like latents, embeddings, or other processing steps.
**The `__call__` Method Structure**
Your `PipelineBlock`'s `__call__` method should follow this structure:
```py
def __call__(self, components, state):
# Get a local view of the state variables this block needs
block_state = self.get_block_state(state)
# Your computation logic here
# block_state contains all your inputs and intermediate_inputs
# You can access them like: block_state.image, block_state.processed_image
# Update the pipeline state with your updated block_states
self.set_block_state(state, block_state)
return components, state
```
The `block_state` object contains all the variables you defined in `inputs` and `intermediate_inputs`, making them easily accessible for your computation.
**Components and Configs**
You can define the components and pipeline-level configs your block needs using `ComponentSpec` and `ConfigSpec`:
```py
from diffusers import ComponentSpec, ConfigSpec
# Define components your block needs
expected_components = [
ComponentSpec(name="unet", type_hint=UNet2DConditionModel),
ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler)
]
# Define pipeline-level configs
expected_config = [
ConfigSpec("force_zeros_for_empty_prompt", True)
]
```
**Components**: In the `ComponentSpec`, you must provide a `name` and ideally a `type_hint`. You can also specify a `default_creation_method` to indicate whether the component should be loaded from a pretrained model or created with default configurations. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [Modular Pipeline Guide](./modular_pipeline.md).
**Configs**: Pipeline-level settings that control behavior across all blocks.
When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block as the first argument of the `__call__` method. You can access any component you need using dot notation:
```py
def __call__(self, components, state):
# Access components using dot notation
unet = components.unet
vae = components.vae
scheduler = components.scheduler
```
That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks
Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.set_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that.
**Helper Function**
```py
from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
import torch
def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
class TestBlock(PipelineBlock):
model_name = "test"
@property
def inputs(self):
return inputs
@property
def intermediate_inputs(self):
return intermediate_inputs
@property
def intermediate_outputs(self):
return intermediate_outputs
@property
def description(self):
return description if description is not None else ""
def __call__(self, components, state):
block_state = self.get_block_state(state)
if block_fn is not None:
block_state = block_fn(block_state, state)
self.set_block_state(state, block_state)
return components, state
return TestBlock
```
## Example: Creating a Simple Pipeline Block
Let's create a simple block to see how these definitions interact with the pipeline state. To better understand what's happening, we'll print out the states before and after updates to inspect them:
```py
inputs = [
InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
]
intermediate_inputs = [InputParam(name="batch_size", type_hint=int)]
intermediate_outputs = [
OutputParam(name="image_latents", description="latents representing the image")
]
def image_encoder_block_fn(block_state, pipeline_state):
print(f"pipeline_state (before update): {pipeline_state}")
print(f"block_state (before update): {block_state}")
# Simulate processing the image
block_state.image = torch.randn(1, 3, 512, 512)
block_state.batch_size = block_state.batch_size * 2
block_state.processed_image = [torch.randn(1, 3, 512, 512)] * block_state.batch_size
block_state.image_latents = torch.randn(1, 4, 64, 64)
print(f"block_state (after update): {block_state}")
return block_state
# Create a block with our definitions
image_encoder_block_cls = make_block(
inputs=inputs,
intermediate_inputs=intermediate_inputs,
intermediate_outputs=intermediate_outputs,
block_fn=image_encoder_block_fn,
description="Encode raw image into its latent presentation"
)
image_encoder_block = image_encoder_block_cls()
pipe = image_encoder_block.init_pipeline()
```
Let's check the pipeline's docstring to see what inputs it expects:
```py
>>> print(pipe.doc)
class TestBlock
Encode raw image into its latent presentation
Inputs:
image (`PIL.Image`, *optional*):
raw input image to process
batch_size (`int`, *optional*):
Outputs:
image_latents (`None`):
latents representing the image
```
Notice that `batch_size` appears as an input even though we defined it as an intermediate input. This happens because no previous block provided it, so the pipeline makes it available as a user input. However, unlike regular inputs, this value goes directly into the mutable intermediate state.
Now let's run the pipeline:
```py
from diffusers.utils import load_image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png")
state = pipe(image=image, batch_size=2)
print(f"pipeline_state (after update): {state}")
```
```out
pipeline_state (before update): PipelineState(
inputs={
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494550>
},
intermediates={
batch_size: 2
},
)
block_state (before update): BlockState(
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494640>
batch_size: 2
)
block_state (after update): BlockState(
image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512]))
batch_size: 4
processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])]
image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64]))
)
pipeline_state (after update): PipelineState(
inputs={
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494550>
},
intermediates={
batch_size: 4
image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64]))
},
)
```
**Key Observations:**
1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`.
2. **After the update**:
- **`image` (inputs)** changed in `block_state` but not in `pipeline_state` - this change is local to the block only.
- **`batch_size (intermediate_inputs)`** was updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict)
- **`image_latents (intermediate_outputs)`** was added to `pipeline_state` because it was declared as an intermediate output
- **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output
@@ -1,189 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# SequentialPipelineBlocks
<Tip warning={true}>
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
</Tip>
`SequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. Unlike `PipelineBlock`, it is a multi-block that composes other blocks together in sequence, creating modular workflows where data flows from one block to the next. It's one of the most common ways to build complex pipelines by combining simpler building blocks.
<Tip>
Other types of multi-blocks include [AutoPipelineBlocks](auto_pipeline_blocks.md) (for conditional block selection) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md).
Additionally, like all `ModularPipelineBlocks`, `SequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
</Tip>
In this tutorial, we will focus on how to create `SequentialPipelineBlocks` and how blocks connect and work together.
The key insight is that blocks connect through their intermediate inputs and outputs - the "studs and anti-studs" we discussed in the [PipelineBlock guide](pipeline_block.md). When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks.
Let's explore this through an example. We will use the same helper function from the PipelineBlock guide to create blocks.
```py
from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
import torch
def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
class TestBlock(PipelineBlock):
model_name = "test"
@property
def inputs(self):
return inputs
@property
def intermediate_inputs(self):
return intermediate_inputs
@property
def intermediate_outputs(self):
return intermediate_outputs
@property
def description(self):
return description if description is not None else ""
def __call__(self, components, state):
block_state = self.get_block_state(state)
if block_fn is not None:
block_state = block_fn(block_state, state)
self.set_block_state(state, block_state)
return components, state
return TestBlock
```
Let's create a block that produces `batch_size`, which we'll call "input_block":
```py
def input_block_fn(block_state, pipeline_state):
batch_size = len(block_state.prompt)
block_state.batch_size = batch_size * block_state.num_images_per_prompt
return block_state
input_block_cls = make_block(
inputs=[
InputParam(name="prompt", type_hint=list, description="list of text prompts"),
InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt")
],
intermediate_outputs=[
OutputParam(name="batch_size", description="calculated batch size")
],
block_fn=input_block_fn,
description="A block that determines batch_size based on the number of prompts and num_images_per_prompt argument."
)
input_block = input_block_cls()
```
Now let's create a second block that uses the `batch_size` from the first block:
```py
def image_encoder_block_fn(block_state, pipeline_state):
# Simulate processing the image
block_state.image = torch.randn(1, 3, 512, 512)
block_state.batch_size = block_state.batch_size * 2
block_state.image_latents = torch.randn(1, 4, 64, 64)
return block_state
image_encoder_block_cls = make_block(
inputs=[
InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
],
intermediate_inputs=[
InputParam(name="batch_size", type_hint=int)
],
intermediate_outputs=[
OutputParam(name="image_latents", description="latents representing the image")
],
block_fn=image_encoder_block_fn,
description="Encode raw image into its latent presentation"
)
image_encoder_block = image_encoder_block_cls()
```
Now let's connect these blocks to create a `SequentialPipelineBlocks`:
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
# Define a dict mapping block names to block instances
blocks_dict = InsertableDict()
blocks_dict["input"] = input_block
blocks_dict["image_encoder"] = image_encoder_block
# Create the SequentialPipelineBlocks
blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
```
Now you have a `SequentialPipelineBlocks` with 2 blocks:
```py
>>> blocks
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
Description:
Sub-Blocks:
[0] input (TestBlock)
Description: A block that determines batch_size based on the number of prompts and num_images_per_prompt argument.
[1] image_encoder (TestBlock)
Description: Encode raw image into its latent presentation
)
```
When you inspect `blocks.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it.
```py
>>> print(blocks.doc)
class SequentialPipelineBlocks
Inputs:
prompt (`None`, *optional*):
num_images_per_prompt (`None`, *optional*):
image (`PIL.Image`, *optional*):
raw input image to process
Outputs:
batch_size (`None`):
image_latents (`None`):
latents representing the image
```
At runtime, you have data flow like this:
![Data Flow Diagram](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/Editor%20_%20Mermaid%20Chart-2025-06-30-092631.png)
**How SequentialPipelineBlocks Works:**
1. Blocks are executed in the order they're registered in the `blocks_dict`
2. Outputs from one block become available as intermediate inputs to all subsequent blocks
3. The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks
4. Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces
What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs.
+14 -25
View File
@@ -174,36 +174,39 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a
### Regional compilation
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.
For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 810x.
Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **810 ×**.
To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
# compile only the repeated transformer layers inside the UNet
pipeline.unet.compile_repeated_blocks(fullgraph=True)
# Compile only the repeated Transformer layers inside the UNet
pipe.unet.compile_repeated_blocks(fullgraph=True)
```
To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.
To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
```py
class MyUNet(ModelMixin):
_repeated_blocks = ("Transformer2DModel",) # ← compiled by default
```
> [!TIP]
> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This is useful for quick experiments because there aren't as many options for you to set which blocks to compile or adjust compilation flags.
```py
# pip install -U accelerate
@@ -216,8 +219,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.
### Graph breaks
@@ -239,12 +242,6 @@ The `step()` function is [called](https://github.com/huggingface/diffusers/blob/
In general, the `sigmas` should [stay on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240) to avoid the communication sync and latency.
<Tip>
Refer to the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post for maximizing performance with `torch.compile` for diffusion models.
</Tip>
### Benchmarks
Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/benchmarks) dataset to see inference latency and memory usage data for compiled pipelines.
@@ -299,11 +296,3 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
```
## Resources
- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).
These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
- Read the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post
to maximize performance when using `torch.compile`.
@@ -14,9 +14,6 @@ specific language governing permissions and limitations under the License.
Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).
> [!TIP]
> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how they can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups.
For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.
For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
@@ -28,7 +25,7 @@ The table below provides a comparison of optimization strategy combinations and
| quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.</small>
<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href="https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d" benchmarking script</a> if you're interested in evaluating your own model.</small>
This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
@@ -0,0 +1,23 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Overview
Welcome to 🧨 Diffusers! If you're new to diffusion models and generative AI, and want to learn more, then you've come to the right place. These beginner-friendly tutorials are designed to provide a gentle introduction to diffusion models and help you understand the library fundamentals - the core components and how 🧨 Diffusers is meant to be used.
You'll learn how to use a pipeline for inference to rapidly generate things, and then deconstruct that pipeline to really understand how to use the library as a modular toolbox for building your own diffusion systems. In the next lesson, you'll learn how to train your own diffusion model to generate what you want.
After completing the tutorials, you'll have gained the necessary skills to start exploring the library on your own and see how to use it for your own projects and applications.
Feel free to join our community on [Discord](https://discord.com/invite/JfAtkvEtRb) or the [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) to connect and collaborate with other users and developers!
Let's start diffusing! 🧨
@@ -0,0 +1,18 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Overview
The inference pipeline supports and enables a wide range of techniques that are divided into two categories:
* Pipeline functionality: these techniques modify the pipeline or extend it for other applications. For example, pipeline callbacks add new features to a pipeline and a pipeline can also be extended for distributed inference.
* Improve inference quality: these techniques increase the visual quality of the generated images. For example, you can enhance your prompts with GPT2 to create better images with lower effort.
@@ -971,7 +971,6 @@ class DreamBoothDataset(Dataset):
def __init__(
self,
args,
instance_data_root,
instance_prompt,
class_prompt,
@@ -981,8 +980,10 @@ class DreamBoothDataset(Dataset):
class_num=None,
size=1024,
repeats=1,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.instance_prompt = instance_prompt
self.custom_instance_prompts = None
@@ -1057,7 +1058,7 @@ class DreamBoothDataset(Dataset):
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
train_resize = transforms.Resize(size, interpolation=interpolation)
train_crop = transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
@@ -1074,11 +1075,11 @@ class DreamBoothDataset(Dataset):
# flip
image = train_flip(image)
if args.center_crop:
y1 = max(0, int(round((image.height - self.size) / 2.0)))
x1 = max(0, int(round((image.width - self.size) / 2.0)))
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
image = train_crop(image)
else:
y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
image = train_transforms(image)
self.pixel_values.append(image)
@@ -1101,7 +1102,7 @@ class DreamBoothDataset(Dataset):
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=interpolation),
transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
@@ -1826,7 +1827,6 @@ def main(args):
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
args=args,
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
train_text_encoder_ti=args.train_text_encoder_ti,
@@ -1836,6 +1836,7 @@ def main(args):
class_num=args.num_class_images,
size=args.resolution,
repeats=args.repeats,
center_crop=args.center_crop,
)
train_dataloader = torch.utils.data.DataLoader(
+1 -46
View File
@@ -87,7 +87,6 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -5480,48 +5479,4 @@ edited_image.save("edited_image.png")
### Note
This model is trained on 512x512, so input size is better on 512x512.
For better editing performance, please refer to this powerful model https://huggingface.co/BleachNick/SD3_UltraEdit_freeform and Paper "UltraEdit: Instruction-based Fine-Grained Image
Editing at Scale", many thanks to their contribution!
# Flux Kontext multiple images
This implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated.
As explained in Section 3 of [the paper](https://arxiv.org/pdf/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
## Example Usage
This pipeline loads two reference images and generates a new image based on them.
```python
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
pipe = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev",
torch_dtype=torch.bfloat16,
custom_pipeline="pipeline_flux_kontext_multiple_images",
)
pipe.to("cuda")
pikachu_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
).convert("RGB")
cat_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
).convert("RGB")
prompts = [
"Pikachu and the cat are sitting together at a pizzeria table, enjoying a delicious pizza.",
]
images = pipe(
multiple_images=[(pikachu_image, cat_image)],
prompt=prompts,
guidance_scale=2.5,
generator=torch.Generator().manual_seed(42),
).images
images[0].save("pizzeria.png")
```
Editing at Scale", many thanks to their contribution!
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -1330,7 +1330,7 @@ def main(args):
# controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor
controlnet_image = controlnet_image * vae.config.scaling_factor
control_block_res_samples = controlnet(
hidden_states=noisy_model_input,
@@ -1614,7 +1614,7 @@ def main(args):
)
if args.cond_image_column is not None:
logger.info("I2I fine-tuning enabled.")
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
@@ -58,7 +58,6 @@ from diffusers.training_utils import (
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
offload_models,
)
from diffusers.utils import (
check_min_version,
@@ -1365,34 +1364,43 @@ def main(args):
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not train_dataset.custom_instance_prompts:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
(
instance_prompt_hidden_states_t5,
instance_prompt_hidden_states_llama3,
instance_pooled_prompt_embeds,
_,
_,
_,
) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(
instance_prompt_hidden_states_t5,
instance_prompt_hidden_states_llama3,
instance_pooled_prompt_embeds,
_,
_,
_,
) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
(class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
validation_embeddings = {}
if args.validation_prompt is not None:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
(
validation_embeddings["prompt_embeds_t5"],
validation_embeddings["prompt_embeds_llama3"],
validation_embeddings["pooled_prompt_embeds"],
validation_embeddings["negative_prompt_embeds_t5"],
validation_embeddings["negative_prompt_embeds_llama3"],
validation_embeddings["negative_pooled_prompt_embeds"],
) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(
validation_embeddings["prompt_embeds_t5"],
validation_embeddings["prompt_embeds_llama3"],
validation_embeddings["pooled_prompt_embeds"],
validation_embeddings["negative_prompt_embeds_t5"],
validation_embeddings["negative_prompt_embeds_llama3"],
validation_embeddings["negative_pooled_prompt_embeds"],
) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1573,10 +1581,12 @@ def main(args):
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
with offload_models(vae, device=accelerator.device, offload=args.offload):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
if args.offload:
vae = vae.to(accelerator.device)
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
if args.offload:
vae = vae.to("cpu")
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
+1 -1
View File
@@ -1,4 +1,4 @@
torch~=2.7.0
torch~=2.4.0
transformers==4.46.1
sentencepiece
aiohttp
+20 -28
View File
@@ -1,10 +1,10 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt
aiohappyeyeballs==2.6.1
aiohappyeyeballs==2.4.3
# via aiohttp
aiohttp==3.12.14
aiohttp==3.10.10
# via -r requirements.in
aiosignal==1.4.0
aiosignal==1.3.1
# via aiohttp
annotated-types==0.7.0
# via pydantic
@@ -29,6 +29,7 @@ filelock==3.16.1
# huggingface-hub
# torch
# transformers
# triton
frozenlist==1.5.0
# via
# aiohttp
@@ -62,42 +63,36 @@ networkx==3.2.1
# via torch
numpy==2.0.2
# via transformers
nvidia-cublas-cu12==12.6.4.1
nvidia-cublas-cu12==12.1.3.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==9.5.1.17
nvidia-cudnn-cu12==9.1.0.70
# via torch
nvidia-cufft-cu12==11.3.0.4
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusolver-cu12==11.7.1.2
# via torch
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparse-cu12==12.1.0.106
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nccl-cu12==2.26.2
# via torch
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvjitlink-cu12==12.9.86
# via
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvtx-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
# via torch
packaging==24.1
# via
@@ -110,9 +105,7 @@ prometheus-client==0.21.0
prometheus-fastapi-instrumentator==7.0.0
# via -r requirements.in
propcache==0.2.0
# via
# aiohttp
# yarl
# via yarl
py-consul==1.5.3
# via -r requirements.in
pydantic==2.9.2
@@ -144,7 +137,7 @@ sympy==1.13.3
# via torch
tokenizers==0.20.1
# via transformers
torch==2.7.0
torch==2.4.1
# via -r requirements.in
tqdm==4.66.5
# via
@@ -152,11 +145,10 @@ tqdm==4.66.5
# transformers
transformers==4.46.1
# via -r requirements.in
triton==3.3.0
triton==3.0.0
# via torch
typing-extensions==4.12.2
# via
# aiosignal
# anyio
# exceptiongroup
# fastapi
@@ -171,5 +163,5 @@ urllib3==2.5.0
# via requests
uvicorn==0.32.0
# via -r requirements.in
yarl==1.18.3
yarl==1.16.0
# via aiohttp
-637
View File
@@ -1,637 +0,0 @@
import argparse
import os
import pathlib
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
from diffusers import (
AutoencoderKLWan,
SkyReelsV2DiffusionForcingPipeline,
SkyReelsV2ImageToVideoPipeline,
SkyReelsV2Pipeline,
SkyReelsV2Transformer3DModel,
UniPCMultistepScheduler,
)
TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
"fps_projection.0": "fps_projection.net.0.proj",
"fps_projection.2": "fps_projection.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
# For the I2V model
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# for the FLF2V model
"img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
def load_sharded_safetensors(dir: pathlib.Path):
if "720P" in str(dir):
file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
else:
file_paths = list(dir.glob("model*.safetensors"))
state_dict = {}
for path in file_paths:
state_dict.update(load_file(path))
return state_dict
def get_transformer_config(model_type: str) -> Dict[str, Any]:
if model_type == "SkyReels-V2-DF-1.3B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-DF-1.3B-540P",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 12,
"inject_sample_info": True,
"num_layers": 30,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "SkyReels-V2-DF-14B-720P":
config = {
"model_id": "Skywork/SkyReels-V2-DF-14B-720P",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "SkyReels-V2-DF-14B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-DF-14B-540P",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "SkyReels-V2-T2V-14B-720P":
config = {
"model_id": "Skywork/SkyReels-V2-T2V-14B-720P",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "SkyReels-V2-T2V-14B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-T2V-14B-540P",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "SkyReels-V2-I2V-1.3B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-1.3B-540P",
"diffusers_config": {
"added_kv_proj_dim": 1536,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 12,
"inject_sample_info": False,
"num_layers": 30,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
},
}
elif model_type == "SkyReels-V2-I2V-14B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-14B-540P",
"diffusers_config": {
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
},
}
elif model_type == "SkyReels-V2-I2V-14B-720P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-14B-720P",
"diffusers_config": {
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
},
}
elif model_type == "SkyReels-V2-FLF2V-1.3B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-1.3B-540P",
"diffusers_config": {
"added_kv_proj_dim": 1536,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 12,
"inject_sample_info": False,
"num_layers": 30,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
"pos_embed_seq_len": 514,
},
}
elif model_type == "SkyReels-V2-FLF2V-14B-540P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-14B-540P",
"diffusers_config": {
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
"pos_embed_seq_len": 514,
},
}
elif model_type == "SkyReels-V2-FLF2V-14B-720P":
config = {
"model_id": "Skywork/SkyReels-V2-I2V-14B-720P",
"diffusers_config": {
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"inject_sample_info": False,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"image_dim": 1280,
"pos_embed_seq_len": 514,
},
}
return config
def convert_transformer(model_type: str):
config = get_transformer_config(model_type)
diffusers_config = config["diffusers_config"]
model_id = config["model_id"]
if "1.3B" in model_type:
original_state_dict = load_file(hf_hub_download(model_id, "model.safetensors"))
else:
os.makedirs(model_type, exist_ok=True)
model_dir = pathlib.Path(model_type)
if "720P" in model_type:
top_shard = 7 if "I2V" in model_type else 6
zeros = "0" * (4 if "I2V" or "T2V" in model_type else 3)
model_name = "diffusion_pytorch_model"
elif "540P" in model_type:
top_shard = 14 if "I2V" in model_type else 12
model_name = "model"
for i in range(1, top_shard + 1):
shard_path = f"{model_name}-{i:05d}-of-{zeros}{top_shard}.safetensors"
hf_hub_download(model_id, shard_path, local_dir=model_dir)
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
transformer = SkyReelsV2Transformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
if "FLF2V" in model_type:
if (
hasattr(transformer.condition_embedder, "image_embedder")
and hasattr(transformer.condition_embedder.image_embedder, "pos_embed")
and transformer.condition_embedder.image_embedder.pos_embed is not None
):
pos_embed_shape = transformer.condition_embedder.image_embedder.pos_embed.shape
original_state_dict["condition_embedder.image_embedder.pos_embed"] = torch.zeros(pos_embed_shape)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def convert_vae():
vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
new_state_dict = {}
# Create mappings for specific components
middle_key_mapping = {
# Encoder middle block
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
"encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
"encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
"encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
"encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
"encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
"encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
"encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
"encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
"encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
"encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
# Decoder middle block
"decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
"decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
"decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
"decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
"decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
"decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
"decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
"decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
"decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
"decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
}
# Create a mapping for attention blocks
attention_mapping = {
# Encoder middle attention
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
"encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
"encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
"encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
"encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
# Decoder middle attention
"decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
"decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
"decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
"decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
}
# Create a mapping for the head components
head_mapping = {
# Encoder head
"encoder.head.0.gamma": "encoder.norm_out.gamma",
"encoder.head.2.bias": "encoder.conv_out.bias",
"encoder.head.2.weight": "encoder.conv_out.weight",
# Decoder head
"decoder.head.0.gamma": "decoder.norm_out.gamma",
"decoder.head.2.bias": "decoder.conv_out.bias",
"decoder.head.2.weight": "decoder.conv_out.weight",
}
# Create a mapping for the quant components
quant_mapping = {
"conv1.weight": "quant_conv.weight",
"conv1.bias": "quant_conv.bias",
"conv2.weight": "post_quant_conv.weight",
"conv2.bias": "post_quant_conv.bias",
}
# Process each key in the state dict
for key, value in old_state_dict.items():
# Handle middle block keys using the mapping
if key in middle_key_mapping:
new_key = middle_key_mapping[key]
new_state_dict[new_key] = value
# Handle attention blocks using the mapping
elif key in attention_mapping:
new_key = attention_mapping[key]
new_state_dict[new_key] = value
# Handle head keys using the mapping
elif key in head_mapping:
new_key = head_mapping[key]
new_state_dict[new_key] = value
# Handle quant keys using the mapping
elif key in quant_mapping:
new_key = quant_mapping[key]
new_state_dict[new_key] = value
# Handle encoder conv1
elif key == "encoder.conv1.weight":
new_state_dict["encoder.conv_in.weight"] = value
elif key == "encoder.conv1.bias":
new_state_dict["encoder.conv_in.bias"] = value
# Handle decoder conv1
elif key == "decoder.conv1.weight":
new_state_dict["decoder.conv_in.weight"] = value
elif key == "decoder.conv1.bias":
new_state_dict["decoder.conv_in.bias"] = value
# Handle encoder downsamples
elif key.startswith("encoder.downsamples."):
# Convert to down_blocks
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
# Convert residual block naming but keep the original structure
if ".residual.0.gamma" in new_key:
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
elif ".residual.2.bias" in new_key:
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
elif ".residual.2.weight" in new_key:
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
elif ".residual.3.gamma" in new_key:
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
elif ".residual.6.bias" in new_key:
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
elif ".residual.6.weight" in new_key:
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
elif ".shortcut.bias" in new_key:
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
elif ".shortcut.weight" in new_key:
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
new_state_dict[new_key] = value
# Handle decoder upsamples
elif key.startswith("decoder.upsamples."):
# Convert to up_blocks
parts = key.split(".")
block_idx = int(parts[2])
# Group residual blocks
if "residual" in key:
if block_idx in [0, 1, 2]:
new_block_idx = 0
resnet_idx = block_idx
elif block_idx in [4, 5, 6]:
new_block_idx = 1
resnet_idx = block_idx - 4
elif block_idx in [8, 9, 10]:
new_block_idx = 2
resnet_idx = block_idx - 8
elif block_idx in [12, 13, 14]:
new_block_idx = 3
resnet_idx = block_idx - 12
else:
# Keep as is for other blocks
new_state_dict[key] = value
continue
# Convert residual block naming
if ".residual.0.gamma" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
elif ".residual.2.bias" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
elif ".residual.2.weight" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
elif ".residual.3.gamma" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
elif ".residual.6.bias" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
elif ".residual.6.weight" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
else:
new_key = key
new_state_dict[new_key] = value
# Handle shortcut connections
elif ".shortcut." in key:
if block_idx == 4:
new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
new_state_dict[new_key] = value
# Handle upsamplers
elif ".resample." in key or ".time_conv." in key:
if block_idx == 3:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
elif block_idx == 7:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
elif block_idx == 11:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_state_dict[new_key] = value
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_state_dict[new_key] = value
else:
# Keep other keys unchanged
new_state_dict[key] = value
with init_empty_weights():
vae = AutoencoderKLWan()
vae.load_state_dict(new_state_dict, strict=True, assign=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--dtype", default="fp32")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
if __name__ == "__main__":
args = get_args()
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
transformer = convert_transformer(args.model_type).to(dtype=dtype)
vae = convert_vae()
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction",
num_train_timesteps=1000,
use_flow_sigmas=True,
)
if "I2V" in args.model_type or "FLF2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
pipe = SkyReelsV2ImageToVideoPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
image_encoder=image_encoder,
image_processor=image_processor,
)
elif "T2V" in args.model_type:
pipe = SkyReelsV2Pipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
)
elif "DF" in args.model_type:
pipe = SkyReelsV2DiffusionForcingPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
)
pipe.save_pretrained(
args.output_path,
safe_serialization=True,
max_shard_size="5GB",
# push_to_hub=True,
# repo_id=f"<place_holder>/{args.model_type}-Diffusers",
)
+1 -1
View File
@@ -110,7 +110,7 @@ _deps = [
"jax>=0.4.1",
"jaxlib>=0.4.1",
"Jinja2",
"k-diffusion==0.0.12",
"k-diffusion>=0.0.12",
"torchsde",
"note_seq",
"librosa",
-76
View File
@@ -34,13 +34,10 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"modular_pipelines": [],
"pipelines": [],
"quantizers.pipe_quant_config": ["PipelineQuantizationConfig"],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
@@ -133,29 +130,14 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"AutoGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"PerturbedAttentionGuidance",
"SkipLayerGuidance",
"SmoothedEnergyGuidance",
"TangentialClassifierFreeGuidance",
]
)
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
)
@@ -163,7 +145,6 @@ else:
[
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AttentionBackendName",
"AuraFlowTransformer2DModel",
"AutoencoderDC",
"AutoencoderKL",
@@ -220,7 +201,6 @@ else:
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
"SkyReelsV2Transformer3DModel",
"SparseControlNetModel",
"StableAudioDiTModel",
"StableCascadeUNet",
@@ -239,15 +219,6 @@ else:
"VQModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"attention_backend",
]
)
_import_structure["modular_pipelines"].extend(
[
"ComponentsManager",
"ComponentSpec",
"ModularPipeline",
"ModularPipelineBlocks",
]
)
_import_structure["optimization"] = [
@@ -362,14 +333,6 @@ except OptionalDependencyNotAvailable:
]
else:
_import_structure["modular_pipelines"].extend(
[
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"WanAutoBlocks",
"WanModularPipeline",
]
)
_import_structure["pipelines"].extend(
[
"AllegroPipeline",
@@ -493,11 +456,6 @@ else:
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
"SkyReelsV2DiffusionForcingPipeline",
"SkyReelsV2DiffusionForcingVideoToVideoPipeline",
"SkyReelsV2ImageToVideoPipeline",
"SkyReelsV2Pipeline",
"StableAudioPipeline",
"StableAudioProjectionModel",
"StableCascadeCombinedPipeline",
@@ -587,7 +545,6 @@ else:
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
@@ -794,32 +751,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .guiders import (
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
)
from .hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
)
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
AttentionBackendName,
AuraFlowTransformer2DModel,
AutoencoderDC,
AutoencoderKL,
@@ -876,7 +819,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
SkyReelsV2Transformer3DModel,
SparseControlNetModel,
StableAudioDiTModel,
T2IAdapter,
@@ -894,13 +836,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
attention_backend,
)
from .modular_pipelines import (
ComponentsManager,
ComponentSpec,
ModularPipeline,
ModularPipelineBlocks,
)
from .optimization import (
get_constant_schedule,
@@ -998,12 +933,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipelines import (
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
WanAutoBlocks,
WanModularPipeline,
)
from .pipelines import (
AllegroPipeline,
AltDiffusionImg2ImgPipeline,
@@ -1124,11 +1053,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
SkyReelsV2DiffusionForcingImageToVideoPipeline,
SkyReelsV2DiffusionForcingPipeline,
SkyReelsV2DiffusionForcingVideoToVideoPipeline,
SkyReelsV2ImageToVideoPipeline,
SkyReelsV2Pipeline,
StableAudioPipeline,
StableAudioProjectionModel,
StableCascadeCombinedPipeline,
-35
View File
@@ -207,38 +207,3 @@ class IPAdapterScaleCutoffCallback(PipelineCallback):
if step_index == cutoff_step:
pipeline.set_ip_adapter_scale(0.0)
return callback_kwargs
class SD3CFGCutoffCallback(PipelineCallback):
"""
Callback function for Stable Diffusion 3 Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
`cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""
tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)
if step_index == cutoff_step:
prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
pooled_prompt_embeds = callback_kwargs[self.tensor_inputs[1]]
pooled_prompt_embeds = pooled_prompt_embeds[
-1:
] # "-1" denotes the embeddings for conditional pooled text tokens.
pipeline._guidance_scale = 0.0
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = pooled_prompt_embeds
return callback_kwargs
-134
View File
@@ -1,134 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage example:
TODO
"""
import ast
import importlib.util
import os
from argparse import ArgumentParser, Namespace
from pathlib import Path
from ..utils import logging
from . import BaseDiffusersCLICommand
EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
CONFIG = "config.json"
def conversion_command_factory(args: Namespace):
return CustomBlocksCommand(args.block_module_name, args.block_class_name)
class CustomBlocksCommand(BaseDiffusersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
conversion_parser = parser.add_parser("custom_blocks")
conversion_parser.add_argument(
"--block_module_name",
type=str,
default="block.py",
help="Module filename in which the custom block will be implemented.",
)
conversion_parser.add_argument(
"--block_class_name",
type=str,
default=None,
help="Name of the custom block. If provided None, we will try to infer it.",
)
conversion_parser.set_defaults(func=conversion_command_factory)
def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
self.logger = logging.get_logger("diffusers-cli/custom_blocks")
self.block_module_name = Path(block_module_name)
self.block_class_name = block_class_name
def run(self):
# determine the block to be saved.
out = self._get_class_names(self.block_module_name)
classes_found = list({cls for cls, _ in out})
if self.block_class_name is not None:
child_class, parent_class = self._choose_block(out, self.block_class_name)
if child_class is None and parent_class is None:
raise ValueError(
"`block_class_name` could not be retrieved. Available classes from "
f"{self.block_module_name}:\n{classes_found}"
)
else:
self.logger.info(
f"Found classes: {classes_found} will be using {classes_found[0]}. "
"If this needs to be changed, re-run the command specifying `block_class_name`."
)
child_class, parent_class = out[0][0], out[0][1]
# dynamically get the custom block and initialize it to call `save_pretrained` in the current directory.
# the user is responsible for running it, so I guess that is safe?
module_name = f"__dynamic__{self.block_module_name.stem}"
spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
getattr(module, child_class)().save_pretrained(os.getcwd())
# or, we could create it manually.
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")
def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:
if cls == chosen:
return cls, base
return None, None
def _get_class_names(self, file_path):
source = file_path.read_text(encoding="utf-8")
try:
tree = ast.parse(source, filename=file_path)
except SyntaxError as e:
raise ValueError(f"Could not parse {file_path!r}: {e}") from e
results: list[tuple[str, str]] = []
for node in tree.body:
if not isinstance(node, ast.ClassDef):
continue
# extract all base names for this class
base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]
# for each allowed base that appears in the class's bases, emit a tuple
for allowed in EXPECTED_PARENT_CLASSES:
if allowed in base_names:
results.append((node.name, allowed))
return results
def _get_base_name(self, node: ast.expr):
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
val = self._get_base_name(node.value)
return f"{val}.{node.attr}" if val else node.attr
return None
def _create_automap(self, parent_class, child_class):
module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
auto_map = {f"{parent_class}": f"{module}.{child_class}"}
return {"auto_map": auto_map}
-2
View File
@@ -15,7 +15,6 @@
from argparse import ArgumentParser
from .custom_blocks import CustomBlocksCommand
from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand
@@ -27,7 +26,6 @@ def main():
# Register commands
EnvironmentCommand.register_subcommand(commands_parser)
FP16SafetensorsCommand.register_subcommand(commands_parser)
CustomBlocksCommand.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()
+1 -10
View File
@@ -176,7 +176,6 @@ class ConfigMixin:
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
subfolder = kwargs.pop("subfolder", None)
self._upload_folder(
save_directory,
@@ -184,7 +183,6 @@ class ConfigMixin:
token=token,
commit_message=commit_message,
create_pr=create_pr,
subfolder=subfolder,
)
@classmethod
@@ -603,10 +601,6 @@ class ConfigMixin:
value = value.tolist()
elif isinstance(value, Path):
value = value.as_posix()
elif hasattr(value, "to_dict") and callable(value.to_dict):
value = value.to_dict()
elif isinstance(value, list):
value = [to_json_saveable(v) for v in value]
return value
if "quantization_config" in config_dict:
@@ -763,7 +757,4 @@ class LegacyConfigMixin(ConfigMixin):
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
if remapped_class is cls:
return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
else:
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
+1 -1
View File
@@ -17,7 +17,7 @@ deps = {
"jax": "jax>=0.4.1",
"jaxlib": "jaxlib>=0.4.1",
"Jinja2": "Jinja2",
"k-diffusion": "k-diffusion==0.0.12",
"k-diffusion": "k-diffusion>=0.0.12",
"torchsde": "torchsde",
"note_seq": "note_seq",
"librosa": "librosa",
-39
View File
@@ -1,39 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
GuiderType = Union[
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
]
@@ -1,188 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AdaptiveProjectedGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_apg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_apg_enabled():
num_conditions += 1
return num_conditions
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
-190
View File
@@ -1,190 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AutoGuidance(BaseGuidance):
"""
AutoGuidance: https://huggingface.co/papers/2406.02507
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
auto_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided.
auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
dropout (`float`, *optional*):
The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
`auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
dropout: Optional[float] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.auto_guidance_layers = auto_guidance_layers
self.auto_guidance_config = auto_guidance_config
self.dropout = dropout
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if auto_guidance_layers is None and auto_guidance_config is None:
raise ValueError(
"Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance."
)
if auto_guidance_layers is not None and auto_guidance_config is not None:
raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
if (dropout is None and auto_guidance_layers is not None) or (
dropout is not None and auto_guidance_layers is None
):
raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")
if auto_guidance_layers is not None:
if isinstance(auto_guidance_layers, int):
auto_guidance_layers = [auto_guidance_layers]
if not isinstance(auto_guidance_layers, list):
raise ValueError(
f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
)
auto_guidance_config = [
LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
]
if isinstance(auto_guidance_config, dict):
auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)
if isinstance(auto_guidance_config, LayerSkipConfig):
auto_guidance_config = [auto_guidance_config]
if not isinstance(auto_guidance_config, list):
raise ValueError(
f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
)
elif isinstance(next(iter(auto_guidance_config), None), dict):
auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]
self.auto_guidance_config = auto_guidance_config
self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_ag_enabled() and self.is_unconditional:
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_ag_enabled() and self.is_unconditional:
for name in self._auto_guidance_hook_names:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_ag_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_ag_enabled():
num_conditions += 1
return num_conditions
def _is_ag_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
@@ -1,141 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class ClassifierFreeGuidance(BaseGuidance):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
proposes scaling and shifting the conditional distribution based on the difference between conditional and
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
@@ -1,152 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class ClassifierFreeZeroStarGuidance(BaseGuidance):
"""
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
quality of generated images.
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
zero_init_steps (`int`, defaults to `1`):
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
zero_init_steps: int = 1,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled():
pred = pred_cond
else:
pred_cond_flat = pred_cond.flatten(1)
pred_uncond_flat = pred_uncond.flatten(1)
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
pred_uncond = pred_uncond * alpha
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
cond_dtype = cond.dtype
cond = cond.float()
uncond = uncond.float()
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
scale = dot_product / squared_norm
return scale.to(dtype=cond_dtype)
-309
View File
@@ -1,309 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self
from ..configuration_utils import ConfigMixin
from ..utils import PushToHubMixin, get_logger
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
GUIDER_CONFIG_NAME = "guider_config.json"
logger = get_logger(__name__) # pylint: disable=invalid-name
class BaseGuidance(ConfigMixin, PushToHubMixin):
r"""Base class providing the skeleton for implementing guidance techniques."""
config_name = GUIDER_CONFIG_NAME
_input_predictions = None
_identifier_key = "__guidance_identifier__"
def __init__(self, start: float = 0.0, stop: float = 1.0):
self._start = start
self._stop = stop
self._step: int = None
self._num_inference_steps: int = None
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
self._enabled = True
if not (0.0 <= start < 1.0):
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
if not (start <= stop <= 1.0):
raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")
if self._input_predictions is None or not isinstance(self._input_predictions, list):
raise ValueError(
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def disable(self):
self._enabled = False
def enable(self):
self._enabled = True
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
self._timestep = timestep
self._count_prepared = 0
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
"""
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
the values of the provided keyword arguments to this method.
Args:
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with a
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
conditional data identifier and the second element must be the unconditional data identifier or None.
Example:
```
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
BaseGuidance.set_input_fields(
latents="latents",
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)
```
"""
for key, value in kwargs.items():
is_string = isinstance(value, str)
is_tuple_of_str_with_len_2 = (
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
)
if not (is_string or is_tuple_of_str_with_len_2):
raise ValueError(
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
subclasses to implement specific model preparation logic.
"""
self._count_prepared += 1
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
"""
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
modifications made during `prepare_models`.
"""
pass
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
def __call__(self, data: List["BlockState"]) -> Any:
if not all(hasattr(d, "noise_pred") for d in data):
raise ValueError("Expected all data to have `noise_pred` attribute.")
if len(data) != self.num_conditions:
raise ValueError(
f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
)
forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
return self.forward(**forward_inputs)
def forward(self, *args, **kwargs) -> Any:
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
@property
def is_conditional(self) -> bool:
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
@property
def is_unconditional(self) -> bool:
return not self.is_conditional
@property
def num_conditions(self) -> int:
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
@classmethod
def _prepare_batch(
cls,
input_fields: Dict[str, Union[str, Tuple[str, str]]],
data: "BlockState",
tuple_index: int,
identifier: str,
) -> "BlockState":
"""
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
Args:
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation. If a string is provided, it will be used as the
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
length 2 is provided, the first element must be the conditional data identifier and the second element
must be the unconditional data identifier or None.
data (`BlockState`):
The input data to be prepared.
tuple_index (`int`):
The index to use when accessing input fields that are tuples.
Returns:
`BlockState`: The prepared batch of data.
"""
from ..modular_pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError(
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
)
data_batch = {}
for key, value in input_fields.items():
try:
if isinstance(value, str):
data_batch[key] = getattr(data, value)
elif isinstance(value, tuple):
data_batch[key] = getattr(data, value[tuple_index])
else:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
) -> Self:
r"""
Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
saved with [`~BaseGuidance.save_pretrained`].
subfolder (`str`, *optional*):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
<Tip>
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
`huggingface-cli login`. You can also activate the special
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
firewalled environment.
</Tip>
"""
config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
return_unused_kwargs=True,
return_commit_hash=True,
**kwargs,
)
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a guider configuration object to a directory so that it can be reloaded using the
[`~BaseGuidance.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
@@ -1,271 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from ..utils import get_logger
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
logger = get_logger(__name__) # pylint: disable=invalid-name
class PerturbedAttentionGuidance(BaseGuidance):
"""
Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
layers.
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
and implementation details.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
perturbed_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for perturbed attention guidance.
perturbed_guidance_start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which perturbed attention guidance starts.
perturbed_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
perturbed_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
If not provided, `perturbed_guidance_config` must be provided.
perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
# NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
# the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
# complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
# for each model architecture.
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
perturbed_guidance_scale: float = 2.8,
perturbed_guidance_start: float = 0.01,
perturbed_guidance_stop: float = 0.2,
perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = perturbed_guidance_scale
self.skip_layer_guidance_start = perturbed_guidance_start
self.skip_layer_guidance_stop = perturbed_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if perturbed_guidance_config is None:
if perturbed_guidance_layers is None:
raise ValueError(
"`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
)
perturbed_guidance_config = LayerSkipConfig(
indices=perturbed_guidance_layers,
fqn="auto",
skip_attention=False,
skip_attention_scores=True,
skip_ff=False,
)
else:
if perturbed_guidance_layers is not None:
raise ValueError(
"`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
)
if isinstance(perturbed_guidance_config, dict):
perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
if isinstance(perturbed_guidance_config, LayerSkipConfig):
perturbed_guidance_config = [perturbed_guidance_config]
if not isinstance(perturbed_guidance_config, list):
raise ValueError(
"`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
)
elif isinstance(next(iter(perturbed_guidance_config), None), dict):
perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
for config in perturbed_guidance_config:
if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
logger.warning(
"Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
"Please check your configuration. Modifying the config to match the expected values."
)
config.skip_attention = False
config.skip_attention_scores = True
config.skip_ff = False
self.skip_layer_config = perturbed_guidance_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero
@@ -1,262 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class SkipLayerGuidance(BaseGuidance):
"""
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
batch of data, apart from the conditional and unconditional batches already used in CFG
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
based on the difference between conditional without skipping and conditional with skipping predictions.
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
version of the model for the conditional prediction).
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
generation quality in video diffusion models.
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
skip_layer_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
skip_layer_guidance_start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which skip layer guidance starts.
skip_layer_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which skip layer guidance stops.
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_start: float = 0.01,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
self.skip_layer_guidance_start = skip_layer_guidance_start
self.skip_layer_guidance_stop = skip_layer_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= skip_layer_guidance_start < 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
)
if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
)
if skip_layer_guidance_layers is None and skip_layer_config is None:
raise ValueError(
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
)
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
if skip_layer_guidance_layers is not None:
if isinstance(skip_layer_guidance_layers, int):
skip_layer_guidance_layers = [skip_layer_guidance_layers]
if not isinstance(skip_layer_guidance_layers, list):
raise ValueError(
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
)
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
if isinstance(skip_layer_config, dict):
skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)
if isinstance(skip_layer_config, LayerSkipConfig):
skip_layer_config = [skip_layer_config]
if not isinstance(skip_layer_config, list):
raise ValueError(
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
)
elif isinstance(next(iter(skip_layer_config), None), dict):
skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]
self.skip_layer_config = skip_layer_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero
@@ -1,251 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class SmoothedEnergyGuidance(BaseGuidance):
"""
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
future without warning or guarantee of reproducibility. This implementation assumes:
- Generated images are square (height == width)
- The model does not combine different modalities together (e.g., text and image latent streams are not combined
together such as Flux)
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
seg_guidance_scale (`float`, defaults to `3.0`):
The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
seg_blur_sigma (`float`, defaults to `9999999.0`):
The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
seg_blur_threshold_inf (`float`, defaults to `9999.0`):
The threshold above which the blur is considered infinite.
seg_guidance_start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
seg_guidance_stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
seg_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
Diffusion 3.5 Medium.
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
seg_guidance_scale: float = 2.8,
seg_blur_sigma: float = 9999999.0,
seg_blur_threshold_inf: float = 9999.0,
seg_guidance_start: float = 0.0,
seg_guidance_stop: float = 1.0,
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale
self.seg_blur_sigma = seg_blur_sigma
self.seg_blur_threshold_inf = seg_blur_threshold_inf
self.seg_guidance_start = seg_guidance_start
self.seg_guidance_stop = seg_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= seg_guidance_start < 1.0):
raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")
if seg_guidance_layers is None and seg_guidance_config is None:
raise ValueError(
"Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
)
if seg_guidance_layers is not None and seg_guidance_config is not None:
raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
if seg_guidance_layers is not None:
if isinstance(seg_guidance_layers, int):
seg_guidance_layers = [seg_guidance_layers]
if not isinstance(seg_guidance_layers, list):
raise ValueError(
f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
)
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
if isinstance(seg_guidance_config, dict):
seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
seg_guidance_config = [seg_guidance_config]
if not isinstance(seg_guidance_config, list):
raise ValueError(
f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
)
elif isinstance(next(iter(seg_guidance_config), None), dict):
seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]
self.seg_guidance_config = seg_guidance_config
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
def cleanup_models(self, denoiser: torch.nn.Module):
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_seg: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_seg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_seg
pred = pred_cond if self.use_original_formulation else pred_cond_seg
pred = pred + self.seg_guidance_scale * shift
elif not self._is_seg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_seg = pred_cond - pred_cond_seg
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_seg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_seg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
return is_within_range and not is_zero
@@ -1,143 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class TangentialClassifierFreeGuidance(BaseGuidance):
"""
Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_tcfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._num_outputs_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_tcfg_enabled():
num_conditions += 1
return num_conditions
def _is_tcfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def normalized_guidance(
pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
) -> torch.Tensor:
cond_dtype = pred_cond.dtype
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
preds = preds.flatten(2)
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
Vh_modified = Vh.clone()
Vh_modified[:, 1] = 0
uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
pred = pred_cond if use_original_formulation else pred_uncond
shift = pred_cond - pred_uncond
pred = pred + guidance_scale * shift
return pred
-2
View File
@@ -20,7 +20,5 @@ if is_torch_available():
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
+1 -14
View File
@@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,16 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
@@ -34,10 +28,3 @@ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
for submodule_name, submodule in module.named_modules():
if submodule_name == fqn:
return submodule
return None
-10
View File
@@ -107,7 +107,6 @@ class TransformerBlockRegistry:
def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
# AttnProcessor2_0
AttentionProcessorRegistry.register(
@@ -125,14 +124,6 @@ def _register_attention_processors_metadata():
),
)
# WanAttnProcessor2_0
AttentionProcessorRegistry.register(
model_class=WanAttnProcessor2_0,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -270,5 +261,4 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# fmt: on
+1 -2
View File
@@ -18,7 +18,6 @@ from typing import Any, Callable, List, Optional, Tuple
import torch
from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
@@ -568,7 +567,7 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
_apply_faster_cache_on_denoiser(module, config)
for name, submodule in module.named_modules():
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
if not isinstance(submodule, _ATTENTION_CLASSES):
continue
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
_apply_faster_cache_on_attention_class(name, submodule, config)
+2 -19
View File
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
@@ -38,7 +37,7 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
_GROUP_ID_LAZY_LEAF = "lazy_leafs"
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -83,7 +82,6 @@ class ModuleGroup:
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
group_id: Optional[int] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -102,10 +100,7 @@ class ModuleGroup:
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
self.group_id = group_id if group_id is not None else str(id(self))
short_hash = _compute_group_hash(self.group_id)
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
all_tensors = []
for module in self.modules:
@@ -614,7 +609,6 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
@@ -627,7 +621,6 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
@@ -662,7 +655,6 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
@@ -694,7 +686,6 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(submodule, group, None, config=config)
modules_with_group_offloading.add(name)
@@ -741,7 +732,6 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(parent_module, group, None, config=config)
@@ -763,7 +753,6 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=False,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=_GROUP_ID_LAZY_LEAF,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
@@ -884,12 +873,6 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
raise ValueError("Group offloading is not enabled for the provided module.")
def _compute_group_hash(group_id):
hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
# first 16 characters for a reasonably short but unique name
return hashed_id[:16]
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
r"""
Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
-263
View File
@@ -1,263 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import (
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
_ATTENTION_CLASSES,
_FEEDFORWARD_CLASSES,
_get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_LAYER_SKIP_HOOK = "layer_skip_hook"
# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
# either remove or make it serializable
@dataclass
class LayerSkipConfig:
r"""
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
provide the correct fqn.
skip_attention (`bool`, defaults to `True`):
Whether to skip attention blocks.
skip_ff (`bool`, defaults to `True`):
Whether to skip feed-forward blocks.
skip_attention_scores (`bool`, defaults to `False`):
Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
projections as the output of scaled dot product attention.
dropout (`float`, defaults to `1.0`):
The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
skipped layers are fully retained, which is equivalent to not skipping any layers.
"""
indices: List[int]
fqn: str = "auto"
skip_attention: bool = True
skip_attention_scores: bool = False
skip_ff: bool = True
dropout: float = 1.0
def __post_init__(self):
if not (0 <= self.dropout <= 1):
raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
def to_dict(self):
return asdict(self)
@staticmethod
def from_dict(data: dict) -> "LayerSkipConfig":
return LayerSkipConfig(**data)
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.nn.functional.scaled_dot_product_attention:
query = kwargs.get("query", None)
key = kwargs.get("key", None)
value = kwargs.get("value", None)
query = query if query is not None else args[0]
key = key if key is not None else args[1]
value = value if value is not None else args[2]
# If the Q sequence length does not match KV sequence length, methods like
# Perturbed Attention Guidance cannot be used (because the caller expects
# the same sequence length as Q, but if we return V here, it will not match).
# When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
# the overall effect would that be of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
if query.shape[2] == value.shape[2]:
return value
return func(*args, **kwargs)
class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
self.skip_processor_output_fn = skip_processor_output_fn
self.skip_attention_scores = skip_attention_scores
self.dropout = dropout
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.skip_attention_scores:
if not math.isclose(self.dropout, 1.0):
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
else:
if math.isclose(self.dropout, 1.0):
output = self.skip_processor_output_fn(module, *args, **kwargs)
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
class FeedForwardSkipHook(ModelHook):
def __init__(self, dropout: float):
super().__init__()
self.dropout = dropout
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if math.isclose(self.dropout, 1.0):
output = kwargs.get("hidden_states", None)
if output is None:
output = kwargs.get("x", None)
if output is None and len(args) > 0:
output = args[0]
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
class TransformerBlockSkipHook(ModelHook):
def __init__(self, dropout: float):
super().__init__()
self.dropout = dropout
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if math.isclose(self.dropout, 1.0):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
if self._metadata.return_encoder_hidden_states_index is None:
output = original_hidden_states
else:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
output = (original_hidden_states, original_encoder_hidden_states)
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
r"""
Apply layer skipping to internal layers of a transformer.
Args:
module (`torch.nn.Module`):
The transformer model to which the layer skip hook should be applied.
config (`LayerSkipConfig`):
The configuration for the layer skip hook.
Example:
```python
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
>>> apply_layer_skip_hook(transformer, config)
```
"""
_apply_layer_skip_hook(module, config)
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
name = name or _LAYER_SKIP_HOOK
if config.skip_attention and config.skip_attention_scores:
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
raise ValueError(
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
)
if len(config.indices) == 0:
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
if config.skip_attention and config.skip_ff:
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = TransformerBlockSkipHook(config.dropout)
registry.register_hook(hook, name)
elif config.skip_attention or config.skip_attention_scores:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
registry.register_hook(hook, name)
if config.skip_ff:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = FeedForwardSkipHook(config.dropout)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
@@ -18,7 +18,6 @@ from typing import Any, Callable, Optional, Tuple, Union
import torch
from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
@@ -228,7 +227,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
config.spatial_attention_block_skip_range = 2
for name, submodule in module.named_modules():
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
if not isinstance(submodule, _ATTENTION_CLASSES):
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
@@ -1,167 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import asdict, dataclass
from typing import List, Optional
import torch
import torch.nn.functional as F
from ..utils import get_logger
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
@dataclass
class SmoothedEnergyGuidanceConfig:
r"""
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
provide the correct fqn.
_query_proj_identifiers (`List[str]`, defaults to `None`):
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
`None`, `to_q` is used by default.
"""
indices: List[int]
fqn: str = "auto"
_query_proj_identifiers: List[str] = None
def to_dict(self):
return asdict(self)
@staticmethod
def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig":
return SmoothedEnergyGuidanceConfig(**data)
class SmoothedEnergyGuidanceHook(ModelHook):
def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
super().__init__()
self.blur_sigma = blur_sigma
self.blur_threshold_inf = blur_threshold_inf
def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
# Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
return smoothed_output
def _apply_smoothed_energy_guidance_hook(
module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None
) -> None:
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
if config._query_proj_identifiers is None:
config._query_proj_identifiers = ["to_q"]
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
for submodule_name, submodule in block.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
continue
for identifier in config._query_proj_identifiers:
query_proj = getattr(submodule, identifier, None)
if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
continue
logger.debug(
f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
)
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
hook = SmoothedEnergyGuidanceHook(blur_sigma)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
"""
This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian blur.
However, some models use joint text-visual token attention for which this may not be suitable. Additionally, this
implementation also assumes that the visual tokens come from a square image/video. In practice, despite these
assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for
Smoothed Energy Guidance.
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
future without warning or guarantee of reproducibility.
"""
assert query.ndim == 3
is_inf = sigma > sigma_threshold_inf
batch_size, seq_len, embed_dim = query.shape
seq_len_sqrt = int(math.sqrt(seq_len))
num_square_tokens = seq_len_sqrt * seq_len_sqrt
query_slice = query[:, :num_square_tokens, :]
query_slice = query_slice.permute(0, 2, 1)
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
if is_inf:
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
kernel_size_half = (kernel_size - 1) / 2
x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
kernel1d = kernel1d.to(query)
kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
query_slice = F.pad(query_slice, padding, mode="reflect")
query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
else:
query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
query_slice = query_slice.permute(0, 2, 1)
query[:, :num_square_tokens, :] = query_slice.clone()
return query
-4
View File
@@ -78,14 +78,12 @@ if is_torch_available():
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
"ModularIPAdapterMixin",
]
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -103,7 +101,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
ModularIPAdapterMixin,
SD3IPAdapterMixin,
)
from .lora_pipeline import (
@@ -120,7 +117,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Mochi1LoraLoaderMixin,
SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
SkyReelsV2LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
+4 -255
View File
@@ -40,6 +40,8 @@ if is_transformers_available():
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
FluxAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
@@ -352,256 +354,6 @@ class IPAdapterMixin:
self.unet.set_attn_processor(attn_procs)
class ModularIPAdapterMixin:
"""Mixin for handling IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
state_dicts.append(state_dict)
unet_name = getattr(self, "unet_name", "unet")
unet = getattr(self, unet_name)
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
extra_loras = unet._load_ip_adapter_loras(state_dicts)
if extra_loras != {}:
if not USE_PEFT_BACKEND:
logger.warning("PEFT backend is required to load these weights.")
else:
# apply the IP Adapter Face ID LoRA weights
peft_config = getattr(unet, "peft_config", {})
for k, lora in extra_loras.items():
if f"faceid_{k}" not in peft_config:
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
def set_ip_adapter_scale(self, scale):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
Example:
```py
# To use original IP-Adapter
scale = 1.0
pipeline.set_ip_adapter_scale(scale)
# To use style block only
scale = {
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
# To use style+layout blocks
scale = {
"down": {"block_2": [0.0, 1.0]},
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
# To use style and layout from 2 reference images
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
pipeline.set_ip_adapter_scale(scales)
```
"""
unet_name = getattr(self, "unet_name", "unet")
unet = getattr(self, unet_name)
if not isinstance(scale, list):
scale = [scale]
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
for attn_name, attn_processor in unet.attn_processors.items():
if isinstance(
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
if isinstance(scale_config, dict):
for k, s in scale_config.items():
if attn_name.startswith(k):
attn_processor.scale[i] = s
else:
attn_processor.scale[i] = scale_config
def unload_ip_adapter(self):
"""
Unloads the IP Adapter weights
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# remove hidden encoder
if self.unet is None:
return
self.unet.encoder_hid_proj = None
self.unet.config.encoder_hid_dim_type = None
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
self.unet.text_encoder_hid_proj = None
self.unet.config.encoder_hid_dim_type = "text_proj"
# restore original Unet attention processors layers
attn_procs = {}
for name, value in self.unet.attn_processors.items():
attn_processor_class = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
)
attn_procs[name] = (
attn_processor_class
if isinstance(
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
)
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
class FluxIPAdapterMixin:
"""Mixin for handling Flux IP Adapters."""
@@ -865,9 +617,6 @@ class FluxIPAdapterMixin:
>>> ...
```
"""
# TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
@@ -887,9 +636,9 @@ class FluxIPAdapterMixin:
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
attn_processor_class = FluxAttnProcessor()
attn_processor_class = FluxAttnProcessor2_0()
attn_procs[name] = (
attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)
+2 -5
View File
@@ -25,6 +25,7 @@ import torch.nn as nn
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
@@ -330,8 +331,6 @@ def _load_lora_into_text_encoder(
hotswap: bool = False,
metadata=None,
):
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -443,8 +442,6 @@ def _func_optionally_disable_offloading(_pipeline):
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
"""
from ..hooks.group_offloading import _is_group_offload_enabled
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
@@ -470,7 +467,7 @@ def _func_optionally_disable_offloading(_pipeline):
for _, component in _pipeline.components.items():
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
continue
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
-398
View File
@@ -5454,404 +5454,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`SkyReelsV2Transformer3DModel`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
if any(k.startswith("diffusion_model.") for k in state_dict):
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
elif any(k.startswith("lora_unet_") for k in state_dict):
state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
@classmethod
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v
def _maybe_expand_t2v_lora_for_i2v(
cls,
transformer: torch.nn.Module,
state_dict,
):
if transformer.config.image_dim is None:
return state_dict
target_device = transformer.device
if any(k.startswith("transformer.blocks.") for k in state_dict):
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
has_bias = any(".lora_B.bias" in k for k in state_dict)
if is_i2v_lora:
return state_dict
for i in range(num_blocks):
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
# These keys should exist if the block `i` was part of the T2V LoRA.
ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
continue
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
)
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
)
# If the original LoRA had biases (indicated by has_bias)
# AND the specific reference bias key exists for this block.
ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
if has_bias and ref_key_lora_B_bias in state_dict:
ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
ref_lora_B_bias_tensor,
device=target_device,
)
return state_dict
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
state_dict = self._maybe_expand_t2v_lora_for_i2v(
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
state_dict=state_dict,
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`SkyReelsV2Transformer3DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
class CogView4LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
+1 -3
View File
@@ -22,6 +22,7 @@ from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
@@ -163,8 +164,6 @@ class PeftAdapterMixin:
from peft import inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -696,7 +695,6 @@ class PeftAdapterMixin:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `unload_lora()`.")
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import recurse_remove_peft_layers
recurse_remove_peft_layers(self)
@@ -24,7 +24,6 @@ from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
@@ -431,7 +430,6 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
empty_device_cache()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -46,7 +46,6 @@ from ..utils import (
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import empty_device_cache
if is_transformers_available():
@@ -1690,7 +1689,6 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
empty_device_cache()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2150,7 +2148,6 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
empty_device_cache()
else:
model.load_state_dict(diffusers_format_checkpoint)
+9 -7
View File
@@ -18,8 +18,11 @@ from ..models.embeddings import (
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
if is_accelerate_available():
@@ -81,12 +84,13 @@ class FluxTransformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
from ..models.attention_processor import (
FluxIPAdapterJointAttnProcessor2_0,
)
if low_cpu_mem_usage:
if is_accelerate_available():
@@ -118,7 +122,7 @@ class FluxTransformer2DLoadersMixin:
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
attn_processor_class = FluxIPAdapterAttnProcessor
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
@@ -154,8 +158,6 @@ class FluxTransformer2DLoadersMixin:
key_id += 1
empty_device_cache()
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-4
View File
@@ -18,7 +18,6 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
logger = logging.get_logger(__name__)
@@ -81,8 +80,6 @@ class SD3Transformer2DLoadersMixin:
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)
empty_device_cache()
return attn_procs
def _convert_ip_adapter_image_proj_to_diffusers(
@@ -150,7 +147,6 @@ class SD3Transformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_proj
+1 -6
View File
@@ -22,6 +22,7 @@ import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
@@ -43,7 +44,6 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.torch_utils import empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
@@ -132,8 +132,6 @@ class UNet2DConditionLoadersMixin:
)
```
"""
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -754,7 +752,6 @@ class UNet2DConditionLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_projection
@@ -852,8 +849,6 @@ class UNet2DConditionLoadersMixin:
key_id += 2
empty_device_cache()
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-4
View File
@@ -26,7 +26,6 @@ _import_structure = {}
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -89,7 +88,6 @@ if is_torch_available():
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
@@ -113,7 +111,6 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel
from .autoencoders import (
AsymmetricAutoencoderKL,
@@ -179,7 +176,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PriorTransformer,
SanaTransformer2DModel,
SD3Transformer2DModel,
SkyReelsV2Transformer3DModel,
StableAudioDiTModel,
T5FilmDecoder,
Transformer2DModel,
+3 -484
View File
@@ -11,504 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
if is_xformers_available():
import xformers as xops
else:
xops = None
logger = logging.get_logger(__name__)
class AttentionMixin:
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
"""
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
for module in self.modules():
if isinstance(module, AttentionModuleMixin):
module.fuse_projections()
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
for module in self.modules():
if isinstance(module, AttentionModuleMixin):
module.unfuse_projections()
class AttentionModuleMixin:
_default_processor_cls = None
_available_processors = []
fused_projections = False
def set_processor(self, processor: AttentionProcessor) -> None:
"""
Set the attention processor to use.
Args:
processor (`AttnProcessor`):
The attention processor to use.
"""
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
"""
Get the attention processor in use.
Args:
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to return the deprecated LoRA attention processor.
Returns:
"AttentionProcessor": The attention processor in use.
"""
if not return_deprecated_lora:
return self.processor
def set_attention_backend(self, backend: str):
from .attention_dispatch import AttentionBackendName
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend.lower())
self.processor._attention_backend = backend
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
"""
Set whether to use NPU flash attention from `torch_npu` or not.
Args:
use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
"""
if use_npu_flash_attention:
if not is_torch_npu_available():
raise ImportError("torch_npu is not available")
self.set_attention_backend("_native_npu")
def set_use_xla_flash_attention(
self,
use_xla_flash_attention: bool,
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
is_flux=False,
) -> None:
"""
Set whether to use XLA flash attention from `torch_xla` or not.
Args:
use_xla_flash_attention (`bool`):
Whether to use pallas flash attention kernel from `torch_xla` or not.
partition_spec (`Tuple[]`, *optional*):
Specify the partition specification if using SPMD. Otherwise None.
is_flux (`bool`, *optional*, defaults to `False`):
Whether the model is a Flux model.
"""
if use_xla_flash_attention:
if not is_torch_xla_available():
raise ImportError("torch_xla is not available")
self.set_attention_backend("_native_xla")
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
) -> None:
"""
Set whether to use memory efficient attention from `xformers` or not.
Args:
use_memory_efficient_attention_xformers (`bool`):
Whether to use memory efficient attention from `xformers` or not.
attention_op (`Callable`, *optional*):
The attention operation to use. Defaults to `None` which uses the default attention operation from
`xformers`.
"""
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
if is_xformers_available():
dtype = None
if attention_op is not None:
op_fw, op_bw = attention_op
dtype, *_ = op_fw.SUPPORTED_DTYPES
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
_ = xops.memory_efficient_attention(q, q, q)
except Exception as e:
raise e
self.set_attention_backend("xformers")
@torch.no_grad()
def fuse_projections(self):
"""
Fuse the query, key, and value projections into a single projection for efficiency.
"""
# Skip if already fused
if getattr(self, "fused_projections", False):
return
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if hasattr(self, "is_cross_attention") and self.is_cross_attention:
# Fuse cross-attention key-value projections
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
if hasattr(self, "use_bias") and self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
self.to_kv.bias.copy_(concatenated_bias)
else:
# Fuse self-attention projections
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if hasattr(self, "use_bias") and self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
# Handle added projections for models like SD3, Flux, etc.
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)
self.fused_projections = True
@torch.no_grad()
def unfuse_projections(self):
"""
Unfuse the query, key, and value projections back to separate projections.
"""
# Skip if not fused
if not getattr(self, "fused_projections", False):
return
# Remove fused projection layers
if hasattr(self, "to_qkv"):
delattr(self, "to_qkv")
if hasattr(self, "to_kv"):
delattr(self, "to_kv")
if hasattr(self, "to_added_qkv"):
delattr(self, "to_added_qkv")
self.fused_projections = False
def set_attention_slice(self, slice_size: int) -> None:
"""
Set the slice size for attention computation.
Args:
slice_size (`int`):
The slice size for attention computation.
"""
if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
processor = None
# Try to get a compatible processor for sliced attention
if slice_size is not None:
processor = self._get_compatible_processor("sliced")
# If no processor was found or slice_size is None, use default processor
if processor is None:
processor = self.default_processor_cls()
self.set_processor(processor)
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
Returns:
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
"""
Reshape the tensor for multi-head attention processing.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
Returns:
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
if tensor.ndim == 3:
batch_size, seq_len, dim = tensor.shape
extra_dim = 1
else:
batch_size, extra_dim, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)
if out_dim == 3:
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
return tensor
def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute the attention scores.
Args:
query (`torch.Tensor`): The query tensor.
key (`torch.Tensor`): The key tensor.
attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
Returns:
`torch.Tensor`: The attention probabilities/scores.
"""
dtype = query.dtype
if self.upcast_attention:
query = query.float()
key = key.float()
if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1
attention_scores = torch.baddbmm(
baddbmm_input,
query,
key.transpose(-1, -2),
beta=beta,
alpha=self.scale,
)
del baddbmm_input
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
del attention_scores
attention_probs = attention_probs.to(dtype)
return attention_probs
def prepare_attention_mask(
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
) -> torch.Tensor:
"""
Prepare the attention mask for the attention computation.
Args:
attention_mask (`torch.Tensor`): The attention mask to prepare.
target_length (`int`): The target length of the attention mask.
batch_size (`int`): The batch size for repeating the attention mask.
out_dim (`int`, *optional*, defaults to `3`): Output dimension.
Returns:
`torch.Tensor`: The prepared attention mask.
"""
head_size = self.heads
if attention_mask is None:
return attention_mask
current_length: int = attention_mask.shape[-1]
if current_length != target_length:
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([attention_mask, padding], dim=2)
else:
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
# we want to instead pad by (0, remaining_length), where remaining_length is:
# remaining_length: int = target_length - current_length
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
return attention_mask
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
"""
Normalize the encoder hidden states.
Args:
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
Returns:
`torch.Tensor`: The normalized encoder hidden states.
"""
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
if isinstance(self.norm_cross, nn.LayerNorm):
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
elif isinstance(self.norm_cross, nn.GroupNorm):
# Group norm norms along the channels dimension and expects
# input to be in the shape of (N, C, *). In this case, we want
# to norm along the hidden dimension, so we need to move
# (batch_size, sequence_length, hidden_size) ->
# (batch_size, hidden_size, sequence_length)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
else:
assert False
return encoder_hidden_states
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
File diff suppressed because it is too large Load Diff
+663 -105
View File
@@ -2272,6 +2272,558 @@ class FusedAuraFlowAttnProcessor2_0:
return hidden_states
class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FluxAttnProcessor2_0_NPU:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FusedFluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
# `context` projections.
if encoder_hidden_states is not None:
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
(
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
) = torch.split(encoder_qkv, split_size, dim=-1)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FusedFluxAttnProcessor2_0_NPU:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
# `context` projections.
if encoder_hidden_states is not None:
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
(
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
) = torch.split(encoder_qkv, split_size, dim=-1)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
ip_hidden_states: Optional[List[torch.Tensor]] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
hidden_states_query_proj = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP-adapter
ip_query = hidden_states_query_proj
ip_attn_output = torch.zeros_like(hidden_states)
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
ip_attn_output += scale * current_ip_hidden_states
return hidden_states, encoder_hidden_states, ip_attn_output
else:
return hidden_states
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -2901,6 +3453,106 @@ class XLAFlashAttnProcessor2_0:
return hidden_states
class XLAFluxFlashAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
"""
def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
if is_torch_xla_version("<", "2.3"):
raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
if is_spmd() and is_torch_xla_version("<", "2.4"):
raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
self.partition_spec = partition_spec
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query /= math.sqrt(head_dim)
hidden_states = flash_attention(query, key, value, causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class MochiVaeAttnProcessor2_0:
r"""
Attention processor used in Mochi VAE.
@@ -5340,6 +5992,17 @@ class LoRAAttnAddedKVProcessor:
pass
class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
super().__init__()
class SanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
@@ -5504,111 +6167,6 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
return hidden_states
class FluxAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxAttnProcessor
return FluxAttnProcessor(*args, **kwargs)
class FluxSingleAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __new__(cls, *args, **kwargs):
deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxAttnProcessor
return FluxAttnProcessor(*args, **kwargs)
class FusedFluxAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxAttnProcessor
return FluxAttnProcessor(*args, **kwargs)
class FluxIPAdapterJointAttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
return FluxIPAdapterAttnProcessor(*args, **kwargs)
class FluxAttnProcessor2_0_NPU:
def __new__(cls, *args, **kwargs):
deprecation_message = (
"FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
"alternative solution to use NPU Flash Attention will be provided in the future."
)
deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
from .transformers.transformer_flux import FluxAttnProcessor
processor = FluxAttnProcessor()
processor._attention_backend = "_native_npu"
return processor
class FusedFluxAttnProcessor2_0_NPU:
def __new__(self):
deprecation_message = (
"FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
"alternative solution to use NPU Flash Attention will be provided in the future."
)
deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
from .transformers.transformer_flux import FluxAttnProcessor
processor = FluxAttnProcessor()
processor._attention_backend = "_fused_npu"
return processor
class XLAFluxFlashAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
"""
def __new__(cls, *args, **kwargs):
deprecation_message = (
"XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
"alternative solution to using XLA Flash Attention will be provided in the future."
)
deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
if is_torch_xla_version("<", "2.3"):
raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
if is_spmd() and is_torch_xla_version("<", "2.4"):
raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
from .transformers.transformer_flux import FluxAttnProcessor
if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
deprecation_message = (
"partition_spec was not used in the processor implementation when it was added. Passing it "
"is a no-op and support for it will be removed."
)
deprecate("partition_spec", "1.0.0", deprecation_message)
processor = FluxAttnProcessor(*args, **kwargs)
processor._attention_backend = "_native_xla"
return processor
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[control_idx]
if from_multi or len(control_type_idx) == 1:
if from_multi:
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
else:
@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
if from_multi or len(control_type_idx) == 1:
if from_multi:
controlnet_cond_fuser += condition + alpha
else:
controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
if from_multi or len(control_type_idx) == 1:
if from_multi:
scales = scales * conditioning_scale[0]
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
elif from_multi or len(control_type_idx) == 1:
elif from_multi:
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
+34 -26
View File
@@ -319,7 +319,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
"""
This function generates 1D positional embeddings from a grid.
@@ -352,11 +352,6 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)
return emb
@@ -1181,7 +1176,6 @@ def apply_rotary_emb(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1199,15 +1193,8 @@ def apply_rotary_emb(
"""
if use_real:
cos, sin = freqs_cis # [S, D]
if sequence_dim == 2:
cos = cos[None, None, :, :]
sin = sin[None, None, :, :]
elif sequence_dim == 1:
cos = cos[None, :, None, :]
sin = sin[None, :, None, :]
else:
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
@@ -1251,6 +1238,37 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
return x
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
class TimestepEmbedding(nn.Module):
def __init__(
self,
@@ -2601,13 +2619,3 @@ class MultiIPAdapterImageProjection(nn.Module):
projected_image_embeds.append(image_embed)
return projected_image_embeds
class FluxPosEmbed(nn.Module):
def __new__(cls, *args, **kwargs):
deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
from .transformers.transformer_flux import FluxPosEmbed
return FluxPosEmbed(*args, **kwargs)
+1 -64
View File
@@ -16,10 +16,9 @@
import importlib
import inspect
import math
import os
from array import array
from collections import OrderedDict, defaultdict
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -39,7 +38,6 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -254,10 +252,6 @@ def load_model_dict_into_meta(
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype
if is_accelerate_version(">", "1.8.1"):
set_module_kwargs["non_blocking"] = True
set_module_kwargs["clear_cache"] = False
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -526,60 +520,3 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters
def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
mismatched_keys = []
if not ignore_mismatched_sizes:
return mismatched_keys
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
def _expand_device_map(device_map, param_names):
"""
Expand a device map to return the correspondence parameter name to device.
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
"""
This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
very large margin.
"""
# Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device)
for param, device in expanded_device_map.items()
if str(device) not in ["cpu", "disk"]
}
parameter_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += math.prod(param.shape)
# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items():
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+41 -89
View File
@@ -62,14 +62,10 @@ from ..utils.hub_utils import (
load_or_create_model_card,
populate_model_card,
)
from ..utils.torch_utils import empty_device_cache
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
_expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_find_mismatched_keys,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
@@ -172,11 +168,7 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
for name, param in parameter.named_parameters():
last_dtype = param.dtype
if (
hasattr(parameter, "_keep_in_fp32_modules")
and parameter._keep_in_fp32_modules
and any(m in name for m in parameter._keep_in_fp32_modules)
):
if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
continue
if param.is_floating_point():
@@ -610,60 +602,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_to_disk_path=offload_to_disk_path,
)
def set_attention_backend(self, backend: str) -> None:
"""
Set the attention backend for the model.
Args:
backend (`str`):
The name of the backend to set. Must be one of the available backends defined in
`AttentionBackendName`. Available backends can be found in
`diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
attention as backend.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
backend = backend.lower()
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
processor._attention_backend = backend
def reset_attention_backend(self) -> None:
"""
Resets the attention backend for the model. Following calls to `forward` will use the environment default or
the torch native scaled dot product attention.
"""
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
processor._attention_backend = None
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -1531,6 +1469,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -1539,27 +1482,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
else:
if offload_folder is not None:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
# If a device map has been used, we can speedup the load time by warming up the device caching allocator.
# If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
# lot of individual calls to device malloc). We can, however, preallocate the memory required by the
# tensors using their expected shape and not performing any initialization of the memory (empty data).
# When the actual device allocations happen, the allocator already has a pool of unused device memory
# that it can re-use for faster loading of the model.
# TODO: add support for warmup with hf_quantizer
if device_map is not None and hf_quantizer is None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
_caching_allocator_warmup(model, expanded_device_map, dtype)
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
state_dict_folder, state_dict_index = None, None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None
if state_dict is not None:
# load_state_dict will manage the case where we pass a dict instead of a file
@@ -1569,14 +1503,38 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
mismatched_keys += _find_mismatched_keys(
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
)
if low_cpu_mem_usage:
@@ -1596,9 +1554,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
empty_device_cache()
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)
@@ -1935,9 +1892,4 @@ class LegacyModelMixin(ModelMixin):
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
if remapped_class is cls:
return super(LegacyModelMixin, remapped_class).from_pretrained(
pretrained_model_name_or_path, **kwargs_copy
)
else:
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
@@ -31,7 +31,6 @@ if is_torch_available():
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
@@ -24,13 +24,19 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward
from ..attention import FeedForward
from ..attention_processor import (
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ..cache_utils import CacheMixin
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
from .transformer_flux import FluxAttention, FluxAttnProcessor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -217,8 +223,6 @@ class ChromaSingleTransformerBlock(nn.Module):
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
from ..attention_processor import FluxAttnProcessor2_0_NPU
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
@@ -226,15 +230,17 @@ class ChromaSingleTransformerBlock(nn.Module):
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
processor = FluxAttnProcessor()
processor = FluxAttnProcessor2_0()
self.attn = FluxAttention(
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
@@ -286,15 +292,17 @@ class ChromaTransformerBlock(nn.Module):
self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
self.attn = FluxAttention(
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=FluxAttnProcessor(),
processor=FluxAttnProcessor2_0(),
qk_norm=qk_norm,
eps=eps,
)
@@ -368,13 +376,7 @@ class ChromaTransformerBlock(nn.Module):
class ChromaTransformer2DModel(
ModelMixin,
ConfigMixin,
PeftAdapterMixin,
FromOriginalModelMixin,
FluxTransformer2DLoadersMixin,
CacheMixin,
AttentionMixin,
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
"""
The Transformer model introduced in Flux, modified for Chroma.
@@ -473,6 +475,106 @@ class ChromaTransformer2DModel(
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedFluxAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
hidden_states: torch.Tensor,
@@ -187,15 +187,9 @@ class CosmosAttnProcessor2_0:
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
# 4. Prepare for GQA
if torch.onnx.is_in_onnx_export():
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
@@ -12,28 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
apply_rotary_emb,
get_1d_rotary_pos_embed,
from ..attention import FeedForward
from ..attention_processor import (
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -42,307 +42,6 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
encoder_query = encoder_key = encoder_value = (None,)
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
class FluxAttnProcessor:
_attention_backend = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "FluxAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class FluxIPAdapterAttnProcessor(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
_attention_backend = None
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
def __call__(
self,
attn: "FluxAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
ip_hidden_states: Optional[List[torch.Tensor]] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size = hidden_states.shape[0]
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
ip_query = query
if encoder_hidden_states is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP-adapter
ip_attn_output = torch.zeros_like(hidden_states)
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
current_ip_hidden_states = dispatch_attention_fn(
ip_query,
ip_key,
ip_value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
)
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
ip_attn_output += scale * current_ip_hidden_states
return hidden_states, encoder_hidden_states, ip_attn_output
else:
return hidden_states
class FluxAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = FluxAttnProcessor
_available_processors = [
FluxAttnProcessor,
FluxIPAdapterAttnProcessor,
]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
@@ -355,8 +54,6 @@ class FluxSingleTransformerBlock(nn.Module):
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
from ..attention_processor import FluxAttnProcessor2_0_NPU
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
@@ -364,15 +61,17 @@ class FluxSingleTransformerBlock(nn.Module):
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
processor = FluxAttnProcessor()
processor = FluxAttnProcessor2_0()
self.attn = FluxAttention(
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
@@ -419,15 +118,17 @@ class FluxTransformerBlock(nn.Module):
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = FluxAttention(
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=FluxAttnProcessor(),
processor=FluxAttnProcessor2_0(),
qk_norm=qk_norm,
eps=eps,
)
@@ -451,7 +152,6 @@ class FluxTransformerBlock(nn.Module):
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
@@ -480,6 +180,7 @@ class FluxTransformerBlock(nn.Module):
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
@@ -494,45 +195,8 @@ class FluxTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
class FluxTransformer2DModel(
ModelMixin,
ConfigMixin,
PeftAdapterMixin,
FromOriginalModelMixin,
FluxTransformer2DLoadersMixin,
CacheMixin,
AttentionMixin,
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
"""
The Transformer model introduced in Flux.
@@ -628,6 +292,106 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedFluxAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
hidden_states: torch.Tensor,
@@ -726,7 +490,6 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
@@ -758,7 +521,6 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
@@ -1,607 +0,0 @@
# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import (
PixArtAlphaTextProjection,
TimestepEmbedding,
get_1d_rotary_pos_embed,
get_1d_sincos_pos_embed_from_grid,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin, get_parameter_dtype
from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SkyReelsV2AttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
# 512 is the context length of the text encoder, hardcoded for now
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if rotary_emb is not None:
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img = attn.add_k_proj(encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
hidden_states_img = F.scaled_dot_product_attention(
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
hidden_states = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding
class SkyReelsV2ImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
if pos_embed_seq_len is not None:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
else:
self.pos_embed = None
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
if self.pos_embed is not None:
batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
return hidden_states
class SkyReelsV2Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, output_type: str = "pt"):
super().__init__()
self.num_channels = num_channels
self.output_type = output_type
self.flip_sin_to_cos = flip_sin_to_cos
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
original_shape = timesteps.shape
t_emb = get_1d_sincos_pos_embed_from_grid(
self.num_channels,
timesteps,
output_type=self.output_type,
flip_sin_to_cos=self.flip_sin_to_cos,
)
# Reshape back to maintain batch structure
if len(original_shape) > 1:
t_emb = t_emb.reshape(*original_shape, self.num_channels)
return t_emb
class SkyReelsV2TimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
self.timesteps_proj = SkyReelsV2Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = SkyReelsV2ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
def forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = get_parameter_dtype(self.time_embedder)
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None:
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
class SkyReelsV2RotaryPosEmbed(nn.Module):
def __init__(
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.patch_size = patch_size
self.max_seq_len = max_seq_len
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
freqs = []
for dim in [t_dim, h_dim, w_dim]:
freq = get_1d_rotary_pos_embed(
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float32
)
freqs.append(freq)
self.freqs = torch.cat(freqs, dim=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
freqs = self.freqs.to(hidden_states.device)
freqs = freqs.split_with_sizes(
[
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
self.attention_head_dim // 6,
self.attention_head_dim // 6,
],
dim=1,
)
freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
return freqs
class SkyReelsV2TransformerBlock(nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: Optional[int] = None,
):
super().__init__()
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = Attention(
query_dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
processor=SkyReelsV2AttnProcessor2_0(),
)
# 2. Cross-attention
self.attn2 = Attention(
query_dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
added_proj_bias=True,
processor=SkyReelsV2AttnProcessor2_0(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
# 3. Feed-forward
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
if temb.dim() == 3:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
elif temb.dim() == 4:
# For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(
hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask
)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
hidden_states
)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
Args:
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
num_attention_heads (`int`, defaults to `16`):
Fixed length for text embeddings.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_dim (`int`, defaults to `4096`):
Input dimension for text embeddings.
freq_dim (`int`, defaults to `256`):
Dimension for sinusoidal time embeddings.
ffn_dim (`int`, defaults to `8192`):
Intermediate dimension in feed-forward network.
num_layers (`int`, defaults to `32`):
The number of layers of transformer blocks to use.
window_size (`Tuple[int]`, defaults to `(-1, -1)`):
Window size for local attention (-1 indicates global attention).
cross_attn_norm (`bool`, defaults to `True`):
Enable cross-attention normalization.
qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
Enable query/key normalization.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
inject_sample_info (`bool`, defaults to `False`):
Whether to inject sample information into the model.
image_dim (`int`, *optional*):
The dimension of the image embeddings.
added_kv_proj_dim (`int`, *optional*):
The dimension of the added key/value projection.
rope_max_seq_len (`int`, defaults to `1024`):
The maximum sequence length for the rotary embeddings.
pos_embed_seq_len (`int`, *optional*):
The sequence length for the positional embeddings.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["SkyReelsV2TransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
num_attention_heads: int = 16,
attention_head_dim: int = 128,
in_channels: int = 16,
out_channels: int = 16,
text_dim: int = 4096,
freq_dim: int = 256,
ffn_dim: int = 8192,
num_layers: int = 32,
cross_attn_norm: bool = True,
qk_norm: Optional[str] = "rms_norm_across_heads",
eps: float = 1e-6,
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
pos_embed_seq_len: Optional[int] = None,
inject_sample_info: bool = False,
num_frame_per_block: int = 1,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Patch & position embedding
self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
# 2. Condition embeddings
# image_embedding_dim=1280 for I2V model
self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=freq_dim,
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
pos_embed_seq_len=pos_embed_seq_len,
)
# 3. Transformer blocks
self.blocks = nn.ModuleList(
[
SkyReelsV2TransformerBlock(
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
)
for _ in range(num_layers)
]
)
# 4. Output norm & projection
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
if inject_sample_info:
self.fps_embedding = nn.Embedding(2, inner_dim)
self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="linear-silu")
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
enable_diffusion_forcing: bool = False,
fps: Optional[torch.Tensor] = None,
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
rotary_emb = self.rope(hidden_states)
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
causal_mask = None
if self.config.num_frame_per_block > 1:
block_num = post_patch_num_frames // self.config.num_frame_per_block
range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave(
self.config.num_frame_per_block
)
causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1)
causal_mask = causal_mask.repeat(
1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width
)
causal_mask = causal_mask.reshape(
post_patch_num_frames * post_patch_height * post_patch_width,
post_patch_num_frames * post_patch_height * post_patch_width,
)
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
)
timestep_proj = timestep_proj.unflatten(-1, (6, -1))
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
if self.config.inject_sample_info:
fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device)
fps_emb = self.fps_embedding(fps)
if enable_diffusion_forcing:
timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat(
timestep.shape[1], 1, 1
)
else:
timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1))
if enable_diffusion_forcing:
b, f = timestep.shape
temb = temb.view(b, f, 1, 1, -1)
timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1) # (b, f, 1, 1, 6, inner_dim)
temb = temb.repeat(1, 1, post_patch_height, post_patch_width, 1).flatten(1, 3)
timestep_proj = timestep_proj.repeat(1, 1, post_patch_height, post_patch_width, 1, 1).flatten(
1, 3
) # (b, f, pp_h, pp_w, 6, inner_dim) -> (b, f * pp_h * pp_w, 6, inner_dim)
timestep_proj = timestep_proj.transpose(1, 2).contiguous() # (b, 6, f * pp_h * pp_w, inner_dim)
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.blocks:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
timestep_proj,
rotary_emb,
causal_mask,
)
else:
for block in self.blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states,
timestep_proj,
rotary_emb,
causal_mask,
)
if temb.dim() == 2:
# If temb is 2D, we assume it has time 1-D time embedding values for each batch.
# For models:
# - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
# - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
# - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
# - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
# - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
elif temb.dim() == 3:
# If temb is 3D, we assume it has 2-D time embedding values for each batch.
# Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
# For models:
# - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
# - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
# - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
shift, scale = shift.squeeze(1), scale.squeeze(1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
# first device rather than the last device, which hidden_states ends up
# on.
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def _set_ar_attention(self, causal_block_size: int):
self.register_to_config(num_frame_per_block=causal_block_size)
@@ -165,7 +165,7 @@ class UNet2DConditionModel(
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"]
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["BasicTransformerBlock"]
@@ -1,86 +0,0 @@
from typing import TYPE_CHECKING
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_pt_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["modular_pipeline"] = [
"ModularPipelineBlocks",
"ModularPipeline",
"PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
"LoopSequentialPipelineBlocks",
"PipelineState",
"BlockState",
]
_import_structure["modular_pipeline_utils"] = [
"ComponentSpec",
"ConfigSpec",
"InputParam",
"OutputParam",
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["components_manager"] = ["ComponentsManager"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .components_manager import ComponentsManager
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
LoopSequentialPipelineBlocks,
ModularPipeline,
ModularPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
from .modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
InsertableDict,
OutputParam,
)
from .stable_diffusion_xl import (
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from .wan import WanAutoBlocks, WanModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,673 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils import is_torch_available, logging
if is_torch_available():
pass
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class InsertableDict(OrderedDict):
def insert(self, key, value, index):
items = list(self.items())
# Remove key if it already exists to avoid duplicates
items = [(k, v) for k, v in items if k != key]
# Insert at the specified index
items.insert(index, (key, value))
# Clear and update self
self.clear()
self.update(items)
# Return self for method chaining
return self
def __repr__(self):
if not self:
return "InsertableDict()"
items = []
for i, (key, value) in enumerate(self.items()):
if isinstance(value, type):
# For classes, show class name and <class ...>
obj_repr = f"<class '{value.__module__}.{value.__name__}'>"
else:
# For objects (instances) and other types, show class name and module
obj_repr = f"<obj '{value.__class__.__module__}.{value.__class__.__name__}'>"
items.append(f"{i}: ({repr(key)}, {obj_repr})")
return "InsertableDict([\n " + ",\n ".join(items) + "\n])"
# YiYi TODO:
# 1. validate the dataclass fields
# 2. improve the docstring and potentially add a validator for load methods, make sure they are valid inputs to pass to from_pretrained()
@dataclass
class ComponentSpec:
"""Specification for a pipeline component.
A component can be created in two ways:
1. From scratch using __init__ with a config dict
2. using `from_pretrained`
Attributes:
name: Name of the component
type_hint: Type of the component (e.g. UNet2DConditionModel)
description: Optional description of the component
config: Optional config dict for __init__ creation
repo: Optional repo path for from_pretrained creation
subfolder: Optional subfolder in repo
variant: Optional variant in repo
revision: Optional revision in repo
default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
"""
name: Optional[str] = None
type_hint: Optional[Type] = None
description: Optional[str] = None
config: Optional[FrozenDict] = None
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default="", metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
def __hash__(self):
"""Make ComponentSpec hashable, using load_id as the hash value."""
return hash((self.name, self.load_id, self.default_creation_method))
def __eq__(self, other):
"""Compare ComponentSpec objects based on name and load_id."""
if not isinstance(other, ComponentSpec):
return False
return (
self.name == other.name
and self.load_id == other.load_id
and self.default_creation_method == other.default_creation_method
)
@classmethod
def from_component(cls, name: str, component: Any) -> Any:
"""Create a ComponentSpec from a Component.
Currently supports:
- Components created with `ComponentSpec.load()` method
- Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders)
Args:
name: Name of the component
component: Component object to create spec from
Returns:
ComponentSpec object
Raises:
ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin)
"""
# Check if component was created with ComponentSpec.load()
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
# component has a usable load_id -> from_pretrained, no warning needed
default_creation_method = "from_pretrained"
else:
# Component doesn't have a usable load_id, check if it's a nn.Module
if isinstance(component, torch.nn.Module):
raise ValueError(
"Cannot create ComponentSpec from a nn.Module that was not created with `ComponentSpec.load()` method."
)
# ConfigMixin objects without weights (e.g. scheduler & guider) can be recreated with from_config
elif isinstance(component, ConfigMixin):
# warn if component was not created with `ComponentSpec`
if not hasattr(component, "_diffusers_load_id"):
logger.warning(
"Component was not created using `ComponentSpec`, defaulting to `from_config` creation method"
)
default_creation_method = "from_config"
else:
# Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error
raise ValueError(
f"Cannot create ComponentSpec from {name}({component.__class__.__name__}). Currently ComponentSpec.from_component() only supports: "
f" - components created with `ComponentSpec.load()` method"
f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)."
)
type_hint = component.__class__
if isinstance(component, ConfigMixin) and default_creation_method == "from_config":
config = component.config
else:
config = None
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
load_spec = cls.decode_load_id(component._diffusers_load_id)
else:
load_spec = {}
return cls(
name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec
)
@classmethod
def loading_fields(cls) -> List[str]:
"""
Return the names of all loadingrelated fields (i.e. those whose field.metadata["loading"] is True).
"""
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
@property
def load_id(self) -> str:
"""
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
segments).
"""
if self.default_creation_method == "from_config":
return "null"
parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p)
@classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
"""
Decode a load_id string back into a dictionary of loading fields and values.
Args:
load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
where None values are represented as "null"
Returns:
Dict mapping loading field names to their values. e.g. {
"repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision"
} If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
component not created with `load` method).
"""
# Get all loading fields in order
loading_fields = cls.loading_fields()
result = {f: None for f in loading_fields}
if load_id == "null":
return result
# Split the load_id
parts = load_id.split("|")
# Map parts to loading fields by position
for i, part in enumerate(parts):
if i < len(loading_fields):
# Convert "null" string back to None
result[loading_fields[i]] = None if part == "null" else part
return result
# YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
# otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
# the config info is lost in the process
# remove error check in from_component spec and ModularPipeline.update_components() if we remove support for non configmixin in `create()` method
def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
"""Create component using from_config with config."""
if self.type_hint is None or not isinstance(self.type_hint, type):
raise ValueError("`type_hint` is required when using from_config creation method.")
config = config or self.config or {}
if issubclass(self.type_hint, ConfigMixin):
component = self.type_hint.from_config(config, **kwargs)
else:
signature_params = inspect.signature(self.type_hint.__init__).parameters
init_kwargs = {}
for k, v in config.items():
if k in signature_params:
init_kwargs[k] = v
for k, v in kwargs.items():
if k in signature_params:
init_kwargs[k] = v
component = self.type_hint(**init_kwargs)
component._diffusers_load_id = "null"
if hasattr(component, "config"):
self.config = component.config
return component
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
def load(self, **kwargs) -> Any:
"""Load component using from_pretrained."""
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
# merge loading field value in the spec with user passed values to create load_kwargs
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
repo = load_kwargs.pop("repo", None)
if repo is None:
raise ValueError(
"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)"
)
if self.type_hint is None:
try:
from diffusers import AutoModel
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
# update type_hint if AutoModel load successfully
self.type_hint = component.__class__
else:
try:
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} using load method: {e}")
self.repo = repo
for k, v in load_kwargs.items():
setattr(self, k, v)
component._diffusers_load_id = self.load_id
return component
@dataclass
class ConfigSpec:
"""Specification for a pipeline configuration parameter."""
name: str
default: Any
description: Optional[str] = None
# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
# however some fields are not relevant for intermediate_inputs
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
# -> should we use different class for inputs and intermediate_inputs?
@dataclass
class InputParam:
"""Specification for an input parameter."""
name: str = None
type_hint: Any = None
default: Any = None
required: bool = False
description: str = ""
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@dataclass
class OutputParam:
"""Specification for an output parameter."""
name: str
type_hint: Any = None
description: str = ""
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
def __repr__(self):
return (
f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
)
def format_inputs_short(inputs):
"""
Format input parameters into a string representation, with required params first followed by optional ones.
Args:
inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params
Returns:
str: Formatted string of input parameters
Example:
>>> inputs = [ ... InputParam(name="prompt", required=True), ... InputParam(name="image", required=True), ...
InputParam(name="guidance_scale", required=False, default=7.5), ... InputParam(name="num_inference_steps",
required=False, default=50) ... ] >>> format_inputs_short(inputs) 'prompt, image, guidance_scale=7.5,
num_inference_steps=50'
"""
required_inputs = [param for param in inputs if param.required]
optional_inputs = [param for param in inputs if not param.required]
required_str = ", ".join(param.name for param in required_inputs)
optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
inputs_str = required_str
if optional_str:
inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
return inputs_str
def format_intermediates_short(intermediate_inputs, required_intermediate_inputs, intermediate_outputs):
"""
Formats intermediate inputs and outputs of a block into a string representation.
Args:
intermediate_inputs: List of intermediate input parameters
required_intermediate_inputs: List of required intermediate input names
intermediate_outputs: List of intermediate output parameters
Returns:
str: Formatted string like:
Intermediates:
- inputs: Required(latents), dtype
- modified: latents # variables that appear in both inputs and outputs
- outputs: images # new outputs only
"""
# Handle inputs
input_parts = []
for inp in intermediate_inputs:
if inp.name in required_intermediate_inputs:
input_parts.append(f"Required({inp.name})")
else:
if inp.name is None and inp.kwargs_type is not None:
inp_name = "*_" + inp.kwargs_type
else:
inp_name = inp.name
input_parts.append(inp_name)
# Handle modified variables (appear in both inputs and outputs)
inputs_set = {inp.name for inp in intermediate_inputs}
modified_parts = []
new_output_parts = []
for out in intermediate_outputs:
if out.name in inputs_set:
modified_parts.append(out.name)
else:
new_output_parts.append(out.name)
result = []
if input_parts:
result.append(f" - inputs: {', '.join(input_parts)}")
if modified_parts:
result.append(f" - modified: {', '.join(modified_parts)}")
if new_output_parts:
result.append(f" - outputs: {', '.join(new_output_parts)}")
return "\n".join(result) if result else " (none)"
def format_params(params, header="Args", indent_level=4, max_line_length=115):
"""Format a list of InputParam or OutputParam objects into a readable string representation.
Args:
params: List of InputParam or OutputParam objects to format
header: Header text to use (e.g. "Args" or "Returns")
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all parameters
"""
if not params:
return ""
base_indent = " " * indent_level
param_indent = " " * (indent_level + 4)
desc_indent = " " * (indent_level + 8)
formatted_params = []
def get_type_str(type_hint):
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
return f"Union[{', '.join(types)}]"
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
def wrap_text(text, indent, max_length):
"""Wrap text while preserving markdown links and maintaining indentation."""
words = text.split()
lines = []
current_line = []
current_length = 0
for word in words:
word_length = len(word) + (1 if current_line else 0)
if current_line and current_length + word_length > max_length:
lines.append(" ".join(current_line))
current_line = [word]
current_length = len(word)
else:
current_line.append(word)
current_length += word_length
if current_line:
lines.append(" ".join(current_line))
return f"\n{indent}".join(lines)
# Add the header
formatted_params.append(f"{base_indent}{header}:")
for param in params:
# Format parameter name and type
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
# YiYi Notes: remove this line if we remove kwargs_type
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
param_str = f"{param_indent}{name} (`{type_str}`"
# Add optional tag and default value if parameter is an InputParam and optional
if hasattr(param, "required"):
if not param.required:
param_str += ", *optional*"
if param.default is not None:
param_str += f", defaults to {param.default}"
param_str += "):"
# Add description on a new line with additional indentation and wrapping
if param.description:
desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description)
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}"
formatted_params.append(param_str)
return "\n\n".join(formatted_params)
def format_input_params(input_params, indent_level=4, max_line_length=115):
"""Format a list of InputParam objects into a readable string representation.
Args:
input_params: List of InputParam objects to format
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all input parameters
"""
return format_params(input_params, "Inputs", indent_level, max_line_length)
def format_output_params(output_params, indent_level=4, max_line_length=115):
"""Format a list of OutputParam objects into a readable string representation.
Args:
output_params: List of OutputParam objects to format
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all output parameters
"""
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ComponentSpec objects into a readable string representation.
Args:
components: List of ComponentSpec objects to format
indent_level: Number of spaces to indent each component line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
add_empty_lines: Whether to add empty lines between components (default: True)
Returns:
A formatted string representing all components
"""
if not components:
return ""
base_indent = " " * indent_level
component_indent = " " * (indent_level + 4)
formatted_components = []
# Add the header
formatted_components.append(f"{base_indent}Components:")
if add_empty_lines:
formatted_components.append("")
# Add each component with optional empty lines between them
for i, component in enumerate(components):
# Get type name, handling special cases
type_name = (
component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
)
component_desc = f"{component_indent}{component.name} (`{type_name}`)"
if component.description:
component_desc += f": {component.description}"
# Get the loading fields dynamically
loading_field_values = []
for field_name in component.loading_fields():
field_value = getattr(component, field_name)
if field_value is not None:
loading_field_values.append(f"{field_name}={field_value}")
# Add loading field information if available
if loading_field_values:
component_desc += f" [{', '.join(loading_field_values)}]"
formatted_components.append(component_desc)
# Add an empty line after each component except the last one
if add_empty_lines and i < len(components) - 1:
formatted_components.append("")
return "\n".join(formatted_components)
def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ConfigSpec objects into a readable string representation.
Args:
configs: List of ConfigSpec objects to format
indent_level: Number of spaces to indent each config line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
add_empty_lines: Whether to add empty lines between configs (default: True)
Returns:
A formatted string representing all configs
"""
if not configs:
return ""
base_indent = " " * indent_level
config_indent = " " * (indent_level + 4)
formatted_configs = []
# Add the header
formatted_configs.append(f"{base_indent}Configs:")
if add_empty_lines:
formatted_configs.append("")
# Add each config with optional empty lines between them
for i, config in enumerate(configs):
config_desc = f"{config_indent}{config.name} (default: {config.default})"
if config.description:
config_desc += f": {config.description}"
formatted_configs.append(config_desc)
# Add an empty line after each config except the last one
if add_empty_lines and i < len(configs) - 1:
formatted_configs.append("")
return "\n".join(formatted_configs)
def make_doc_string(
inputs,
intermediate_inputs,
outputs,
description="",
class_name=None,
expected_components=None,
expected_configs=None,
):
"""
Generates a formatted documentation string describing the pipeline block's parameters and structure.
Args:
inputs: List of input parameters
intermediate_inputs: List of intermediate input parameters
outputs: List of output parameters
description (str, *optional*): Description of the block
class_name (str, *optional*): Name of the class to include in the documentation
expected_components (List[ComponentSpec], *optional*): List of expected components
expected_configs (List[ConfigSpec], *optional*): List of expected configurations
Returns:
str: A formatted string containing information about components, configs, call parameters,
intermediate inputs/outputs, and final outputs.
"""
output = ""
# Add class name if provided
if class_name:
output += f"class {class_name}\n\n"
# Add description
if description:
desc_lines = description.strip().split("\n")
aligned_desc = "\n".join(" " + line for line in desc_lines)
output += aligned_desc + "\n\n"
# Add components section if provided
if expected_components and len(expected_components) > 0:
components_str = format_components(expected_components, indent_level=2)
output += components_str + "\n\n"
# Add configs section if provided
if expected_configs and len(expected_configs) > 0:
configs_str = format_configs(expected_configs, indent_level=2)
output += configs_str + "\n\n"
# Add inputs section
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
# Add outputs section
output += "\n\n"
output += format_output_params(outputs, indent_level=2)
return output
@@ -1,665 +0,0 @@
import json
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from ..configuration_utils import ConfigMixin
from ..image_processor import PipelineImageInput
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
from .modular_pipeline_utils import InputParam
logger = logging.getLogger(__name__)
# YiYi Notes: this is actually for SDXL, put it here for now
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam(
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
),
"prompt_2": InputParam(
"prompt_2",
type_hint=Union[str, List[str]],
description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
),
"negative_prompt": InputParam(
"negative_prompt",
type_hint=Union[str, List[str]],
description="The prompt or prompts not to guide the image generation",
),
"negative_prompt_2": InputParam(
"negative_prompt_2",
type_hint=Union[str, List[str]],
description="The negative prompt or prompts for text_encoder_2",
),
"cross_attention_kwargs": InputParam(
"cross_attention_kwargs",
type_hint=Optional[dict],
description="Kwargs dictionary passed to the AttentionProcessor",
),
"clip_skip": InputParam(
"clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
),
"image": InputParam(
"image",
type_hint=PipelineImageInput,
required=True,
description="The image(s) to modify for img2img or inpainting",
),
"mask_image": InputParam(
"mask_image",
type_hint=PipelineImageInput,
required=True,
description="Mask image for inpainting, white pixels will be repainted",
),
"generator": InputParam(
"generator",
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
description="Generator(s) for deterministic generation",
),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam(
"num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
),
"timesteps": InputParam(
"timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
),
"sigmas": InputParam(
"sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
),
"denoising_end": InputParam(
"denoising_end",
type_hint=Optional[float],
description="Fraction of denoising process to complete before termination",
),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam(
"strength", type_hint=float, default=0.3, description="How much to transform the reference image"
),
"denoising_start": InputParam(
"denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
),
"latents": InputParam(
"latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
),
"padding_mask_crop": InputParam(
"padding_mask_crop",
type_hint=Optional[Tuple[int, int]],
description="Size of margin in crop for image and mask",
),
"original_size": InputParam(
"original_size",
type_hint=Optional[Tuple[int, int]],
description="Original size of the image for SDXL's micro-conditioning",
),
"target_size": InputParam(
"target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
),
"negative_original_size": InputParam(
"negative_original_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on image resolution",
),
"negative_target_size": InputParam(
"negative_target_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on target resolution",
),
"crops_coords_top_left": InputParam(
"crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Top-left coordinates for SDXL's micro-conditioning",
),
"negative_crops_coords_top_left": InputParam(
"negative_crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Negative conditioning crop coordinates",
),
"aesthetic_score": InputParam(
"aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
),
"negative_aesthetic_score": InputParam(
"negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam(
"output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
),
"ip_adapter_image": InputParam(
"ip_adapter_image",
type_hint=PipelineImageInput,
required=True,
description="Image(s) to be used as IP adapter",
),
"control_image": InputParam(
"control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
),
"control_guidance_start": InputParam(
"control_guidance_start",
type_hint=Union[float, List[float]],
default=0.0,
description="When ControlNet starts applying",
),
"control_guidance_end": InputParam(
"control_guidance_end",
type_hint=Union[float, List[float]],
default=1.0,
description="When ControlNet stops applying",
),
"controlnet_conditioning_scale": InputParam(
"controlnet_conditioning_scale",
type_hint=Union[float, List[float]],
default=1.0,
description="Scale factor for ControlNet outputs",
),
"guess_mode": InputParam(
"guess_mode",
type_hint=bool,
default=False,
description="Enables ControlNet encoder to recognize input without prompts",
),
"control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
),
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
required=True,
description="Text embeddings used to guide image generation",
),
"negative_prompt_embeds": InputParam(
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
),
"pooled_prompt_embeds": InputParam(
"pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
),
"negative_pooled_prompt_embeds": InputParam(
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
),
"latents": InputParam(
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
),
"latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
),
"image_latents": InputParam(
"image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam(
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
),
"add_time_ids": InputParam(
"add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
),
"negative_add_time_ids": InputParam(
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam(
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
),
"negative_ip_adapter_embeds": InputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative image embeddings for IP-Adapter",
),
"images": InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
required=True,
description="Generated images",
),
}
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
DEFAULT_PARAM_MAPS = {
"prompt": {
"label": "Prompt",
"type": "string",
"default": "a bear sitting in a chair drinking a milkshake",
"display": "textarea",
},
"negative_prompt": {
"label": "Negative Prompt",
"type": "string",
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
"display": "textarea",
},
"num_inference_steps": {
"label": "Steps",
"type": "int",
"default": 25,
"min": 1,
"max": 1000,
},
"seed": {
"label": "Seed",
"type": "int",
"default": 0,
"min": 0,
"display": "random",
},
"width": {
"label": "Width",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"height": {
"label": "Height",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"images": {
"label": "Images",
"type": "image",
"display": "output",
},
"image": {
"label": "Image",
"type": "image",
"display": "input",
},
}
DEFAULT_TYPE_MAPS = {
"int": {
"type": "int",
"default": 0,
"min": 0,
},
"float": {
"type": "float",
"default": 0.0,
"min": 0.0,
},
"str": {
"type": "string",
"default": "",
},
"bool": {
"type": "boolean",
"default": False,
},
"image": {
"type": "image",
},
}
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
"text_encoders": ["text_encoder", "tokenizer"],
"ip_adapter_embeds": ["ip_adapter_embeds"],
"prompt_embeddings": ["prompt_embeds"],
}
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
"""
Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" ->
"text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
"""
if name is None:
return None
for group_name, group_keys in group_params_keys.items():
for group_key in group_keys:
if group_key in name:
return group_name
return None
class ModularNode(ConfigMixin):
"""
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
around a ModularPipelineBlocks object.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
"""
config_name = "node_config.json"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
blocks = ModularPipelineBlocks.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
return cls(blocks, **kwargs)
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
self.blocks = blocks
if label is None:
label = self.blocks.__class__.__name__
# blocks param name -> mellon param name
self.name_mapping = {}
input_params = {}
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
# "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
for inp in inputs:
param = kwargs.pop(inp.name, None)
if param:
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
input_params[inp.name] = param
mellon_name = param.pop("name", inp.name)
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
elif get_group_name(inp.name):
param = get_group_name(inp.name)
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
type_str = str(inp.type_hint).lower()
else:
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param = type_param.copy()
param["label"] = inp.name
param["display"] = "input"
break
else:
param = inp.name
# add the param dict to the inp_params dict
input_params[inp.name] = param
component_params = {}
for comp in self.blocks.expected_components:
param = kwargs.pop(comp.name, None)
if param:
component_params[comp.name] = param
mellon_name = param.pop("name", comp.name)
if mellon_name != comp.name:
self.name_mapping[comp.name] = mellon_name
continue
to_exclude = False
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
if exclude_key in comp.name:
to_exclude = True
break
if to_exclude:
continue
if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
self.name_mapping[comp.name] = param
elif comp.name in DEFAULT_MODEL_KEYS:
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
else:
param = comp.name
# add the param dict to the model_params dict
component_params[comp.name] = param
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.sub_blocks.keys())[-1]
outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
else:
outputs = self.blocks.intermediate_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
if param:
output_params[out.name] = param
mellon_name = param.pop("name", out.name)
if mellon_name != out.name:
self.name_mapping[out.name] = mellon_name
continue
if out.name in DEFAULT_PARAM_MAPS:
param = DEFAULT_PARAM_MAPS[out.name].copy()
param["display"] = "output"
else:
group_name = get_group_name(out.name)
if group_name:
param = group_name
if out.name not in self.name_mapping:
self.name_mapping[out.name] = param
else:
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param
if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")
register_dict = {
"category": category,
"label": label,
"input_params": input_params,
"component_params": component_params,
"output_params": output_params,
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)
def setup(self, components_manager, collection=None):
self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
self._components_manager = components_manager
@property
def mellon_config(self):
return self._convert_to_mellon_config()
def _convert_to_mellon_config(self):
node = {}
node["label"] = self.config.label
node["category"] = self.config.category
node_param = {}
for inp_name, inp_param in self.config.input_params.items():
if inp_name in self.name_mapping:
mellon_name = self.name_mapping[inp_name]
else:
mellon_name = inp_name
if isinstance(inp_param, str):
param = {
"label": inp_param,
"type": inp_param,
"display": "input",
}
else:
param = inp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
else:
mellon_name = comp_name
if isinstance(comp_param, str):
param = {
"label": comp_param,
"type": comp_param,
"display": "input",
}
else:
param = comp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
else:
mellon_name = out_name
if isinstance(out_param, str):
param = {
"label": out_param,
"type": out_param,
"display": "output",
}
else:
param = out_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
node["params"] = node_param
return node
def save_mellon_config(self, file_path):
"""
Save the Mellon configuration to a JSON file.
Args:
file_path (str or Path): Path where the JSON file will be saved
Returns:
Path: Path to the saved config file
"""
file_path = Path(file_path)
# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)
# Create a combined dictionary with module definition and name mapping
config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
# Save the config to file
with open(file_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
logger.info(f"Mellon config and name mapping saved to {file_path}")
return file_path
@classmethod
def load_mellon_config(cls, file_path):
"""
Load a Mellon configuration from a JSON file.
Args:
file_path (str or Path): Path to the JSON file containing Mellon config
Returns:
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")
with open(file_path, "r", encoding="utf-8") as f:
config = json.load(f)
logger.info(f"Mellon config loaded from {file_path}")
return config
def process_inputs(self, **kwargs):
params_components = {}
for comp_name, comp_param in self.config.component_params.items():
logger.debug(f"component: {comp_name}")
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
if mellon_comp_name in kwargs:
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
comp = kwargs[mellon_comp_name].pop(comp_name)
else:
comp = kwargs.pop(mellon_comp_name)
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
if mellon_inp_name in kwargs:
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
inp = kwargs[mellon_inp_name].pop(inp_name)
else:
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp
return_output_names = list(self.config.output_params.keys())
return params_components, params_run, return_output_names
def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
self.pipeline.update_components(**params_components)
output = self.pipeline(**params_run, output=return_output_names)
return output
@@ -1,77 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"CONTROLNET_BLOCKS",
"IMAGE2IMAGE_BLOCKS",
"INPAINT_BLOCKS",
"IP_ADAPTER_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLAutoControlnetStep",
"StableDiffusionXLAutoDecodeStep",
"StableDiffusionXLAutoIPAdapterStep",
"StableDiffusionXLAutoVaeEncoderStep",
]
_import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .encoders import (
StableDiffusionXLTextEncoderStep,
)
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
CONTROLNET_BLOCKS,
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
IP_ADAPTER_BLOCKS,
TEXT2IMAGE_BLOCKS,
StableDiffusionXLAutoBlocks,
StableDiffusionXLAutoControlnetStep,
StableDiffusionXLAutoDecodeStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
)
from .modular_pipeline import StableDiffusionXLModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
File diff suppressed because it is too large Load Diff
@@ -1,215 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@staticmethod
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components
def upcast_vae(components):
dtype = components.vae.dtype
components.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
components.vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
components.vae.post_quant_conv.to(dtype)
components.vae.decoder.conv_in.to(dtype)
components.vae.decoder.mid_block.to(dtype)
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if not block_state.output_type == "latent":
latents = block_state.latents
# make sure the VAE is in float32 mode, as it overflows in float16
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
if block_state.needs_upcasting:
self.upcast_vae(components)
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != components.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
components.vae = components.vae.to(latents.dtype)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
block_state.has_latents_mean = (
hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
)
block_state.has_latents_std = (
hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
)
if block_state.has_latents_mean and block_state.has_latents_std:
block_state.latents_mean = (
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
block_state.latents_std = (
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = (
latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
)
else:
latents = latents / components.vae.config.scaling_factor
block_state.images = components.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if block_state.needs_upcasting:
components.vae.to(dtype=torch.float16)
else:
block_state.images = block_state.latents
# apply watermark if available
if hasattr(components, "watermark") and components.watermark is not None:
block_state.images = components.watermark.apply_watermark(block_state.images)
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"A post-processing step that overlays the mask on the image (inpainting task only).\n"
+ "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("image"),
InputParam("mask_image"),
InputParam("padding_mask_crop"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="The generated images from the decode step",
),
InputParam(
"crops_coords",
type_hint=Tuple[int, int],
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.",
),
]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if block_state.padding_mask_crop is not None and block_state.crops_coords is not None:
block_state.images = [
components.image_processor.apply_overlay(
block_state.mask_image, block_state.image, i, block_state.crops_coords
)
for i in block_state.images
]
self.set_block_state(state, block_state)
return components, state
@@ -1,791 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import ControlNetModel, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi experimenting composible denoise loop
# loop step (1): prepare latent input for denoiser
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that prepare the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
]
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
return components, block_state
# loop step (1): prepare latent input for denoiser (with inpainting)
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object"
)
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"mask",
type_hint=Optional[torch.Tensor],
description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
),
InputParam(
"masked_image_latents",
type_hint=Optional[torch.Tensor],
description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
),
]
@staticmethod
def check_inputs(components, block_state):
num_channels_unet = components.num_channels_unet
if num_channels_unet == 9:
# default case for runwayml/stable-diffusion-inpainting
if block_state.mask is None or block_state.masked_image_latents is None:
raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet")
num_channels_latents = block_state.latents.shape[1]
num_channels_mask = block_state.mask.shape[1]
num_channels_masked_image = block_state.masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
raise ValueError(
f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects"
f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `components.unet` or your `mask_image` or `image` input."
)
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
self.check_inputs(components, block_state)
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
if components.num_channels_unet == 9:
block_state.scaled_latents = torch.cat(
[block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1
)
return components, block_state
# loop step (2): denoise the latents with guidance
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"timestep_cond",
type_hint=Optional[torch.Tensor],
description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.",
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds, "
"add_time_ids/negative_add_time_ids, "
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
"please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@torch.no_grad()
def __call__(
self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
"time_ids": ("add_time_ids", "negative_add_time_ids"),
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.unet)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
block_state.scaled_latents,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=block_state.timestep_cond,
cross_attention_kwargs=block_state.cross_attention_kwargs,
added_cond_kwargs=cond_kwargs,
return_dict=False,
)[0]
components.guider.cleanup_models(components.unet)
# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
return components, block_state
# loop step (2): denoise the latents with guidance (with controlnet)
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetModel),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that denoise the latents with guidance (with controlnet). "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"controlnet_cond",
required=True,
type_hint=torch.Tensor,
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"conditioning_scale",
type_hint=float,
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"guess_mode",
required=True,
type_hint=bool,
description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"controlnet_keep",
required=True,
type_hint=List[float],
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"timestep_cond",
type_hint=Optional[torch.Tensor],
description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds, "
"add_time_ids/negative_add_time_ids, "
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
"please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
InputParam(
kwargs_type="controlnet_kwargs",
description=(
"additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )"
"please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@staticmethod
def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
accepted_kwargs = set(inspect.signature(func).parameters.keys())
extra_kwargs = {}
for key, value in kwargs.items():
if key in accepted_kwargs and key not in exclude_kwargs:
extra_kwargs[key] = value
return extra_kwargs
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
extra_controlnet_kwargs = self.prepare_extra_kwargs(
components.controlnet.forward, **block_state.controlnet_kwargs
)
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
"time_ids": ("add_time_ids", "negative_add_time_ids"),
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
}
# cond_scale for the timestep (controlnet input)
if isinstance(block_state.controlnet_keep[i], list):
block_state.cond_scale = [
c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])
]
else:
controlnet_cond_scale = block_state.conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
# default controlnet output/unet input for guess mode + conditional path
block_state.down_block_res_samples_zeros = None
block_state.mid_block_res_sample_zeros = None
# guided denoiser step
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.unet)
# Prepare additional conditionings
added_cond_kwargs = {
"text_embeds": guider_state_batch.text_embeds,
"time_ids": guider_state_batch.time_ids,
}
if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None:
added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds
# Prepare controlnet additional conditionings
controlnet_added_cond_kwargs = {
"text_embeds": guider_state_batch.text_embeds,
"time_ids": guider_state_batch.time_ids,
}
# run controlnet for the guidance batch
if block_state.guess_mode and not components.guider.is_conditional:
# guider always run uncond batch first, so these tensors should be set already
down_block_res_samples = block_state.down_block_res_samples_zeros
mid_block_res_sample = block_state.mid_block_res_sample_zeros
else:
down_block_res_samples, mid_block_res_sample = components.controlnet(
block_state.scaled_latents,
t,
encoder_hidden_states=guider_state_batch.prompt_embeds,
controlnet_cond=block_state.controlnet_cond,
conditioning_scale=block_state.cond_scale,
guess_mode=block_state.guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
return_dict=False,
**extra_controlnet_kwargs,
)
# assign it to block_state so it will be available for the uncond guidance batch
if block_state.down_block_res_samples_zeros is None:
block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples]
if block_state.mid_block_res_sample_zeros is None:
block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample)
# Predict the noise
# store the noise_pred in guider_state_batch so we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
block_state.scaled_latents,
t,
encoder_hidden_states=guider_state_batch.prompt_embeds,
timestep_cond=block_state.timestep_cond,
cross_attention_kwargs=block_state.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
components.guider.cleanup_models(components.unet)
# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
return components, block_state
# loop step (3): scheduler step to update latents
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that update the latents. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("eta", default=0.0),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
# YiYi TODO: move this out of here
@staticmethod
def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
accepted_kwargs = set(inspect.signature(func).parameters.keys())
extra_kwargs = {}
for key, value in kwargs.items():
if key in accepted_kwargs and key not in exclude_kwargs:
extra_kwargs[key] = value
return extra_kwargs
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
block_state.extra_step_kwargs = self.prepare_extra_kwargs(
components.scheduler.step, generator=block_state.generator, eta=block_state.eta
)
# Perform scheduler step using the predicted output
block_state.latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred,
t,
block_state.latents,
**block_state.extra_step_kwargs,
**block_state.scheduler_step_kwargs,
return_dict=False,
)[0]
if block_state.latents.dtype != block_state.latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
block_state.latents = block_state.latents.to(block_state.latents_dtype)
return components, block_state
# loop step (3): scheduler step to update latents (with inpainting)
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that update the latents (for inpainting workflow only). "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("eta", default=0.0),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"mask",
type_hint=Optional[torch.Tensor],
description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
),
InputParam(
"noise",
type_hint=Optional[torch.Tensor],
description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step.",
),
InputParam(
"image_latents",
type_hint=Optional[torch.Tensor],
description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@staticmethod
def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
accepted_kwargs = set(inspect.signature(func).parameters.keys())
extra_kwargs = {}
for key, value in kwargs.items():
if key in accepted_kwargs and key not in exclude_kwargs:
extra_kwargs[key] = value
return extra_kwargs
def check_inputs(self, components, block_state):
if components.num_channels_unet == 4:
if block_state.image_latents is None:
raise ValueError(f"image_latents is required for this step {self.__class__.__name__}")
if block_state.mask is None:
raise ValueError(f"mask is required for this step {self.__class__.__name__}")
if block_state.noise is None:
raise ValueError(f"noise is required for this step {self.__class__.__name__}")
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
self.check_inputs(components, block_state)
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
block_state.extra_step_kwargs = self.prepare_extra_kwargs(
components.scheduler.step, generator=block_state.generator, eta=block_state.eta
)
# Perform scheduler step using the predicted output
block_state.latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred,
t,
block_state.latents,
**block_state.extra_step_kwargs,
**block_state.scheduler_step_kwargs,
return_dict=False,
)[0]
if block_state.latents.dtype != block_state.latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
block_state.latents = block_state.latents.to(block_state.latents_dtype)
# adjust latent for inpainting
if components.num_channels_unet == 4:
block_state.init_latents_proper = block_state.image_latents
if i < len(block_state.timesteps) - 1:
block_state.noise_timestep = block_state.timesteps[i + 1]
block_state.init_latents_proper = components.scheduler.add_noise(
block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep])
)
block_state.latents = (
1 - block_state.mask
) * block_state.init_latents_proper + block_state.mask * block_state.latents
return components, block_state
# the loop wrapper that iterates over the timesteps
class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoise the latents over `timesteps`. "
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
if block_state.disable_guidance:
components.guider.disable()
else:
components.guider.enable()
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
self.set_block_state(state, block_state)
return components, state
# composing the denoising loops
class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [
StableDiffusionXLLoopBeforeDenoiser,
StableDiffusionXLLoopDenoiser,
StableDiffusionXLLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `StableDiffusionXLLoopBeforeDenoiser`\n"
" - `StableDiffusionXLLoopDenoiser`\n"
" - `StableDiffusionXLLoopAfterDenoiser`\n"
"This block supports both text2img and img2img tasks."
)
# control_cond
class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [
StableDiffusionXLLoopBeforeDenoiser,
StableDiffusionXLControlNetLoopDenoiser,
StableDiffusionXLLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents with controlnet. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `StableDiffusionXLLoopBeforeDenoiser`\n"
" - `StableDiffusionXLControlNetLoopDenoiser`\n"
" - `StableDiffusionXLLoopAfterDenoiser`\n"
"This block supports using controlnet for both text2img and img2img tasks."
)
# mask
class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [
StableDiffusionXLInpaintLoopBeforeDenoiser,
StableDiffusionXLLoopDenoiser,
StableDiffusionXLInpaintLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents(for inpainting task only). \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
" - `StableDiffusionXLLoopDenoiser`\n"
" - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
"This block onlysupports inpainting tasks."
)
# control_cond + mask
class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [
StableDiffusionXLInpaintLoopBeforeDenoiser,
StableDiffusionXLControlNetLoopDenoiser,
StableDiffusionXLInpaintLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
" - `StableDiffusionXLControlNetLoopDenoiser`\n"
" - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
"This block only supports using controlnet for inpainting tasks."
)
@@ -1,902 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
import torch
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"IP Adapter step that prepares ip adapter image embeddings.\n"
"Note that this step only prepares the embeddings - in order for it to work correctly, "
"you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n"
"See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
" for more details"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
ComponentSpec(
"feature_extractor",
CLIPImageProcessor,
config=FrozenDict({"size": 224, "crop_size": 224}),
default_creation_method="from_config",
),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"ip_adapter_image",
PipelineImageInput,
required=True,
description="The image(s) to be used as ip adapter",
)
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
OutputParam(
"negative_ip_adapter_embeds",
type_hint=torch.Tensor,
description="Negative IP adapter image embeddings",
),
]
@staticmethod
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components
def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(components.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = components.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = components.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = components.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self,
components,
ip_adapter_image,
ip_adapter_image_embeds,
device,
num_images_per_prompt,
prepare_unconditional_embeds,
):
image_embeds = []
if prepare_unconditional_embeds:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
components, single_ip_adapter_image, device, 1, output_hidden_state
)
image_embeds.append(single_image_embeds[None, :])
if prepare_unconditional_embeds:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if prepare_unconditional_embeds:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if prepare_unconditional_embeds:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
components,
ip_adapter_image=block_state.ip_adapter_image,
ip_adapter_image_embeds=None,
device=block_state.device,
num_images_per_prompt=1,
prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
)
if block_state.prepare_unconditional_embeds:
block_state.negative_ip_adapter_embeds = []
for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
negative_image_embeds, image_embeds = image_embeds.chunk(2)
block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
block_state.ip_adapter_embeds[i] = image_embeds
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return "Text Encoder step that generate text_embeddings to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", CLIPTextModel),
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("tokenizer_2", CLIPTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [ConfigSpec("force_zeros_for_empty_prompt", True)]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("negative_prompt"),
InputParam("negative_prompt_2"),
InputParam("cross_attention_kwargs"),
InputParam("clip_skip"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="negative text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="pooled text embeddings used to guide the image generation",
),
OutputParam(
"negative_pooled_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="negative pooled text embeddings used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
elif block_state.prompt_2 is not None and (
not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)
):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
@staticmethod
def encode_prompt(
components,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prepare_unconditional_embeds (`bool`):
whether to use prepare unconditional embeddings or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
device = device or components._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
components._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if components.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
else:
scale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
else:
scale_lora_layers(components.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
tokenizers = (
[components.tokenizer, components.tokenizer_2]
if components.tokenizer is not None
else [components.tokenizer_2]
)
text_encoders = (
[components.text_encoder, components.text_encoder_2]
if components.text_encoder is not None
else [components.text_encoder_2]
)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, tokenizer)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif prepare_unconditional_embeds and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
if components.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if prepare_unconditional_embeds:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
if components.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(
dtype=components.text_encoder_2.dtype, device=device
)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if prepare_unconditional_embeds:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if components.text_encoder is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input prompt
block_state.text_encoder_lora_scale = (
block_state.cross_attention_kwargs.get("scale", None)
if block_state.cross_attention_kwargs is not None
else None
)
(
block_state.prompt_embeds,
block_state.negative_prompt_embeds,
block_state.pooled_prompt_embeds,
block_state.negative_pooled_prompt_embeds,
) = self.encode_prompt(
components,
block_state.prompt,
block_state.prompt_2,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
block_state.negative_prompt_2,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
lora_scale=block_state.text_encoder_lora_scale,
clip_skip=block_state.clip_skip,
)
# Add outputs
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return "Vae Encoder step that encode the input image into a latent representation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam(
"preprocess_kwargs",
type_hint=Optional[dict],
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=torch.Tensor,
description="The latents representing the reference image for image-to-image/inpainting generation",
)
]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
)
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = block_state.image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.image_latents = self._encode_vae_image(
components, image=block_state.image, generator=block_state.generator
)
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
ComponentSpec(
"mask_processor",
VaeImageProcessor,
config=FrozenDict(
{"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}
),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Vae encoder step that prepares the image and mask for the inpainting process"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height"),
InputParam("width"),
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"
),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
OutputParam(
"masked_image_latents",
type_hint=torch.Tensor,
description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)",
),
OutputParam(
"crops_coords",
type_hint=Optional[Tuple[int, int]],
description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
),
]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# do not accept do_classifier_free_guidance
def prepare_mask_latents(
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
if block_state.height is None:
block_state.height = components.default_height
if block_state.width is None:
block_state.width = components.default_width
if block_state.padding_mask_crop is not None:
block_state.crops_coords = components.mask_processor.get_crop_region(
block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
)
block_state.resize_mode = "fill"
else:
block_state.crops_coords = None
block_state.resize_mode = "default"
block_state.image = components.image_processor.preprocess(
block_state.image,
height=block_state.height,
width=block_state.width,
crops_coords=block_state.crops_coords,
resize_mode=block_state.resize_mode,
)
block_state.image = block_state.image.to(dtype=torch.float32)
block_state.mask = components.mask_processor.preprocess(
block_state.mask_image,
height=block_state.height,
width=block_state.width,
resize_mode=block_state.resize_mode,
crops_coords=block_state.crops_coords,
)
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
block_state.batch_size = block_state.image.shape[0]
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.image_latents = self._encode_vae_image(
components, image=block_state.image, generator=block_state.generator
)
# 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components,
block_state.mask,
block_state.masked_image,
block_state.batch_size,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
)
self.set_block_state(state, block_state)
return components, state
@@ -1,380 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
StableDiffusionXLControlNetInputStep,
StableDiffusionXLControlNetUnionInputStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLInputStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep,
)
from .decoders import (
StableDiffusionXLDecodeStep,
StableDiffusionXLInpaintOverlayMaskStep,
)
from .denoise import (
StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDenoiseStep,
StableDiffusionXLInpaintControlNetDenoiseStep,
StableDiffusionXLInpaintDenoiseStep,
)
from .encoders import (
StableDiffusionXLInpaintVaeEncoderStep,
StableDiffusionXLIPAdapterStep,
StableDiffusionXLTextEncoderStep,
StableDiffusionXLVaeEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# auto blocks & sequential blocks & mappings
# vae encoder (run before before_denoise)
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for both inpainting and img2img tasks.\n"
+ " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided."
+ " - if neither `mask_image` nor `image` is provided, step will be skipped."
)
# optional ip-adapter (run before input step)
class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLIPAdapterStep]
block_names = ["ip_adapter"]
block_trigger_inputs = ["ip_adapter_image"]
@property
def description(self):
return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
# before_denoise: text2img
class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLInputStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
)
# before_denoise: img2img
class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLInputStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
)
# before_denoise: inpainting
class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLInputStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for inpainting task.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
)
# before_denoise: all task (text2img, img2img, inpainting)
class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [
StableDiffusionXLInpaintBeforeDenoiseStep,
StableDiffusionXLImg2ImgBeforeDenoiseStep,
StableDiffusionXLBeforeDenoiseStep,
]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image_latents", None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n"
+ " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n"
+ " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
+ " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n"
)
# optional controlnet input step (after before_denoise, before denoise)
# works for both controlnet and controlnet_union
class StableDiffusionXLAutoControlNetInputStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep]
block_names = ["controlnet_union", "controlnet"]
block_trigger_inputs = ["control_mode", "control_image"]
@property
def description(self):
return (
"Controlnet Input step that prepare the controlnet input.\n"
+ "This is an auto pipeline block that works for both controlnet and controlnet_union.\n"
+ " (it should be called right before the denoise step)"
+ " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n"
+ " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided."
+ " - if neither `control_mode` nor `control_image` is provided, step will be skipped."
)
# denoise: controlnet (text2img, img2img, inpainting)
class StableDiffusionXLAutoControlNetDenoiseStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep]
block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"]
block_trigger_inputs = ["mask", "controlnet_cond"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents with controlnet. "
"This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks."
"This block should not be used without a controlnet_cond input"
" - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided."
" - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided."
" - If neither mask nor controlnet_cond are provided, step will be skipped."
)
# denoise: all task with or without controlnet (text2img, img2img, inpainting)
class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
StableDiffusionXLAutoControlNetDenoiseStep,
StableDiffusionXLInpaintDenoiseStep,
StableDiffusionXLDenoiseStep,
]
block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
block_trigger_inputs = ["controlnet_cond", "mask", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet."
" - `StableDiffusionXLAutoControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (support controlnet withtext2img, img2img and inpainting tasks)."
" - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided (support inpainting tasks)."
" - `StableDiffusionXLDenoiseStep` (denoise) is used when neither mask nor controlnet_cond are provided (support text2img and img2img tasks)."
)
# decode: inpaint
class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep]
block_names = ["decode", "mask_overlay"]
@property
def description(self):
return (
"Inpaint decode step that decode the denoised latents into images outputs.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n"
+ " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image"
)
# decode: all task (text2img, img2img, inpainting)
class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
block_names = ["inpaint", "non-inpaint"]
block_trigger_inputs = ["padding_mask_crop", None]
@property
def description(self):
return (
"Decode step that decode the denoised latents into images outputs.\n"
+ "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n"
+ " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n"
+ " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
)
# ip-adapter, controlnet, text2img, img2img, inpainting
class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLAutoBeforeDenoiseStep,
StableDiffusionXLAutoControlNetInputStep,
StableDiffusionXLAutoDenoiseStep,
StableDiffusionXLAutoDecodeStep,
]
block_names = [
"text_encoder",
"ip_adapter",
"image_encoder",
"before_denoise",
"controlnet_input",
"denoise",
"decoder",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n"
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`\n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
+ "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
+ "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)
# controlnet (input + denoise step)
class StableDiffusionXLAutoControlnetStep(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLAutoControlNetInputStep,
StableDiffusionXLAutoControlNetDenoiseStep,
]
block_names = ["controlnet_input", "controlnet_denoise"]
@property
def description(self):
return (
"Controlnet auto step that prepare the controlnet input and denoise the latents. "
+ "It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks."
+ " (it should be replace at 'denoise' step)"
)
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLSetTimestepsStep),
("prepare_latents", StableDiffusionXLPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("decode", StableDiffusionXLDecodeStep),
]
)
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("image_encoder", StableDiffusionXLVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("decode", StableDiffusionXLDecodeStep),
]
)
INPAINT_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLInpaintDenoiseStep),
("decode", StableDiffusionXLInpaintDecodeStep),
]
)
CONTROLNET_BLOCKS = InsertableDict(
[
("denoise", StableDiffusionXLAutoControlnetStep),
]
)
IP_ADAPTER_BLOCKS = InsertableDict(
[
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
("controlnet_input", StableDiffusionXLAutoControlNetInputStep),
("denoise", StableDiffusionXLAutoDenoiseStep),
("decode", StableDiffusionXLAutoDecodeStep),
]
)
ALL_BLOCKS = {
"text2img": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"inpaint": INPAINT_BLOCKS,
"controlnet": CONTROLNET_BLOCKS,
"ip_adapter": IP_ADAPTER_BLOCKS,
"auto": AUTO_BLOCKS,
}
@@ -1,376 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from ...image_processor import PipelineImageInput
from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...utils import logging
from ..modular_pipeline import ModularPipeline
from ..modular_pipeline_utils import InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder?
# YiYi Notes: model specific components:
## (1) it should inherit from ModularPipeline
## (2) acts like a container that holds components and configs
## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin)
## (5) how to use together with Components_manager?
class StableDiffusionXLModularPipeline(
ModularPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
ModularIPAdapterMixin,
):
"""
A ModularPipeline for Stable Diffusion XL.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
"""
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
default_sample_size = 128
if hasattr(self, "unet") and self.unet is not None:
default_sample_size = self.unet.config.sample_size
return default_sample_size
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_unet(self):
num_channels_unet = 4
if hasattr(self, "unet") and self.unet is not None:
num_channels_unet = self.unet.config.in_channels
return num_channels_unet
@property
def num_channels_latents(self):
num_channels_latents = 4
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.latent_channels
return num_channels_latents
# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks
# auto_docstring
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam(
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
),
"prompt_2": InputParam(
"prompt_2",
type_hint=Union[str, List[str]],
description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
),
"negative_prompt": InputParam(
"negative_prompt",
type_hint=Union[str, List[str]],
description="The prompt or prompts not to guide the image generation",
),
"negative_prompt_2": InputParam(
"negative_prompt_2",
type_hint=Union[str, List[str]],
description="The negative prompt or prompts for text_encoder_2",
),
"cross_attention_kwargs": InputParam(
"cross_attention_kwargs",
type_hint=Optional[dict],
description="Kwargs dictionary passed to the AttentionProcessor",
),
"clip_skip": InputParam(
"clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
),
"image": InputParam(
"image",
type_hint=PipelineImageInput,
required=True,
description="The image(s) to modify for img2img or inpainting",
),
"mask_image": InputParam(
"mask_image",
type_hint=PipelineImageInput,
required=True,
description="Mask image for inpainting, white pixels will be repainted",
),
"generator": InputParam(
"generator",
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
description="Generator(s) for deterministic generation",
),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam(
"num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
),
"timesteps": InputParam(
"timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
),
"sigmas": InputParam(
"sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
),
"denoising_end": InputParam(
"denoising_end",
type_hint=Optional[float],
description="Fraction of denoising process to complete before termination",
),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam(
"strength", type_hint=float, default=0.3, description="How much to transform the reference image"
),
"denoising_start": InputParam(
"denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
),
"latents": InputParam(
"latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
),
"padding_mask_crop": InputParam(
"padding_mask_crop",
type_hint=Optional[Tuple[int, int]],
description="Size of margin in crop for image and mask",
),
"original_size": InputParam(
"original_size",
type_hint=Optional[Tuple[int, int]],
description="Original size of the image for SDXL's micro-conditioning",
),
"target_size": InputParam(
"target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
),
"negative_original_size": InputParam(
"negative_original_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on image resolution",
),
"negative_target_size": InputParam(
"negative_target_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on target resolution",
),
"crops_coords_top_left": InputParam(
"crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Top-left coordinates for SDXL's micro-conditioning",
),
"negative_crops_coords_top_left": InputParam(
"negative_crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Negative conditioning crop coordinates",
),
"aesthetic_score": InputParam(
"aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
),
"negative_aesthetic_score": InputParam(
"negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam(
"output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
),
"ip_adapter_image": InputParam(
"ip_adapter_image",
type_hint=PipelineImageInput,
required=True,
description="Image(s) to be used as IP adapter",
),
"control_image": InputParam(
"control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
),
"control_guidance_start": InputParam(
"control_guidance_start",
type_hint=Union[float, List[float]],
default=0.0,
description="When ControlNet starts applying",
),
"control_guidance_end": InputParam(
"control_guidance_end",
type_hint=Union[float, List[float]],
default=1.0,
description="When ControlNet stops applying",
),
"controlnet_conditioning_scale": InputParam(
"controlnet_conditioning_scale",
type_hint=Union[float, List[float]],
default=1.0,
description="Scale factor for ControlNet outputs",
),
"guess_mode": InputParam(
"guess_mode",
type_hint=bool,
default=False,
description="Enables ControlNet encoder to recognize input without prompts",
),
"control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
),
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
required=True,
description="Text embeddings used to guide image generation",
),
"negative_prompt_embeds": InputParam(
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
),
"pooled_prompt_embeds": InputParam(
"pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
),
"negative_pooled_prompt_embeds": InputParam(
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
),
"latents": InputParam(
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
),
"latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
),
"image_latents": InputParam(
"image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam(
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
),
"add_time_ids": InputParam(
"add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
),
"negative_add_time_ids": InputParam(
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam(
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
),
"negative_ip_adapter_embeds": InputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative image embeddings for IP-Adapter",
),
"images": InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
required=True,
description="Generated images",
),
}
SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
"prompt_embeds": OutputParam(
"prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"
),
"negative_prompt_embeds": OutputParam(
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
),
"pooled_prompt_embeds": OutputParam(
"pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"
),
"negative_pooled_prompt_embeds": OutputParam(
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
),
"batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"),
"dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"image_latents": OutputParam(
"image_latents", type_hint=torch.Tensor, description="Latents representing reference image"
),
"mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"),
"masked_image_latents": OutputParam(
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
),
"crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"),
"num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"),
"latent_timestep": OutputParam(
"latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"
),
"add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"),
"negative_add_time_ids": OutputParam(
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
),
"timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
"noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"ip_adapter_embeds": OutputParam(
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
),
"negative_ip_adapter_embeds": OutputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative image embeddings for IP-Adapter",
),
"images": OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="Generated images",
),
}
SDXL_OUTPUTS_SCHEMA = {
"images": OutputParam(
"images",
type_hint=Union[
Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput
],
description="The final generated images",
)
}
@@ -1,66 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = ["WanTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"TEXT2VIDEO_BLOCKS",
"WanAutoBeforeDenoiseStep",
"WanAutoBlocks",
"WanAutoBlocks",
"WanAutoDecodeStep",
"WanAutoDenoiseStep",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .encoders import WanTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
TEXT2VIDEO_BLOCKS,
WanAutoBeforeDenoiseStep,
WanAutoBlocks,
WanAutoDecodeStep,
WanAutoDenoiseStep,
)
from .modular_pipeline import WanModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -1,365 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Union
import torch
from ...schedulers import UniPCMultistepScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that
# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by
# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the
# configuration of guider is.
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class WanInputStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return (
"Input processing step that:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n"
"All input tensors are expected to have either batch_size=1 or match the batch_size\n"
"of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
"have a final batch_size of batch_size * num_videos_per_prompt."
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_videos_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
description="negative text embeddings used to guide the image generation",
),
]
def check_inputs(self, components, block_state):
if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
f" {block_state.negative_prompt_embeds.shape}."
)
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
)
if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
1, block_state.num_videos_per_prompt, 1
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
class WanSetTimestepsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", UniPCMultistepScheduler),
]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for inference"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
"num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
components.scheduler,
block_state.num_inference_steps,
block_state.device,
block_state.timesteps,
block_state.sigmas,
)
self.set_block_state(state, block_state)
return components, state
class WanPrepareLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def description(self) -> str:
return "Prepare latents step that prepares the latents for the text-to-video generation process"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("num_frames", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_videos_per_prompt", type_hint=int, default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. Can be generated in input step.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
)
]
@staticmethod
def check_inputs(components, block_state):
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
)
if block_state.num_frames is not None and (
block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
):
raise ValueError(
f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
)
@staticmethod
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp
def prepare_latents(
comp,
batch_size: int,
num_channels_latents: int = 16,
height: int = 480,
width: int = 832,
num_frames: int = 81,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_latent_frames = (num_frames - 1) // comp.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_channels_latents,
num_latent_frames,
int(height) // comp.vae_scale_factor_spatial,
int(width) // comp.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.num_frames = block_state.num_frames or components.default_num_frames
block_state.device = components._execution_device
block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
block_state.latents = self.prepare_latents(
components,
block_state.batch_size * block_state.num_videos_per_prompt,
block_state.num_channels_latents,
block_state.height,
block_state.width,
block_state.num_frames,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.latents,
)
self.set_block_state(state, block_state)
return components, state
@@ -1,105 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLWan
from ...utils import logging
from ...video_processor import VideoProcessor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanDecodeStep(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLWan),
ComponentSpec(
"video_processor",
VideoProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"videos",
type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]],
description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae_dtype = components.vae.dtype
if not block_state.output_type == "latent":
latents = block_state.latents
latents_mean = (
torch.tensor(components.vae.config.latents_mean)
.view(1, components.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
1, components.vae.config.z_dim, 1, 1, 1
).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
latents = latents.to(vae_dtype)
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
else:
block_state.videos = block_state.latents
block_state.videos = components.video_processor.postprocess_video(
block_state.videos, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state
@@ -1,261 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import WanTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanLoopDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", WanTransformer3DModel),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds. "
"Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@torch.no_grad()
def __call__(
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
}
transformer_dtype = components.transformer.dtype
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latents.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds,
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
)[0]
components.guider.cleanup_models(components.transformer)
# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
return components, block_state
class WanLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", UniPCMultistepScheduler),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that update the latents. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return []
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
# Perform scheduler step using the predicted output
latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred.float(),
t,
block_state.latents.float(),
**block_state.scheduler_step_kwargs,
return_dict=False,
)[0]
if block_state.latents.dtype != latents_dtype:
block_state.latents = block_state.latents.to(latents_dtype)
return components, block_state
class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoise the latents over `timesteps`. "
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
ComponentSpec("scheduler", UniPCMultistepScheduler),
ComponentSpec("transformer", WanTransformer3DModel),
]
@property
def loop_intermediate_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
self.set_block_state(state, block_state)
return components, state
class WanDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanLoopDenoiser,
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports both text2vid tasks."
)
@@ -1,242 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
from typing import List, Optional, Union
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...utils import is_ftfy_available, logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
if is_ftfy_available():
import ftfy
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
text = whitespace_clean(basic_clean(text))
return text
class WanTextEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "Text Encoder step that generate text_embeddings to guide the video generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", UMT5EncoderModel),
ComponentSpec("tokenizer", AutoTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("negative_prompt"),
InputParam("attention_kwargs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="guider_input_fields",
description="negative text embeddings used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
@staticmethod
def _get_t5_prompt_embeds(
components,
prompt: Union[str, List[str]],
max_sequence_length: int,
device: torch.device,
):
dtype = components.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
text_inputs = components.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
return prompt_embeds
@staticmethod
def encode_prompt(
components,
prompt: str,
device: Optional[torch.device] = None,
num_videos_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_videos_per_prompt (`int`):
number of videos that should be generated per prompt
prepare_unconditional_embeds (`bool`):
whether to use prepare unconditional embeddings or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
max_sequence_length (`int`, defaults to `512`):
The maximum number of text tokens to be used for the generation process.
"""
device = device or components._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device)
if prepare_unconditional_embeds and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(
components, negative_prompt, max_sequence_length, device
)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
if prepare_unconditional_embeds:
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input prompt
(
block_state.prompt_embeds,
block_state.negative_prompt_embeds,
) = self.encode_prompt(
components,
block_state.prompt,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
prompt_embeds=None,
negative_prompt_embeds=None,
)
# Add outputs
self.set_block_state(state, block_state)
return components, state
@@ -1,144 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
WanInputStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
)
from .decoders import WanDecodeStep
from .denoise import WanDenoiseStep
from .encoders import WanTextEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# before_denoise: text2vid
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
]
block_names = ["input", "set_timesteps", "prepare_latents"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
)
# before_denoise: all task (text2vid,)
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanBeforeDenoiseStep,
]
block_names = ["text2vid"]
block_trigger_inputs = [None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2vid.\n"
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
)
# denoise: text2vid
class WanAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanDenoiseStep,
]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2vid tasks.."
" - `WanDenoiseStep` (denoise) for text2vid tasks."
)
# decode: all task (text2img, img2img, inpainting)
class WanAutoDecodeStep(AutoPipelineBlocks):
block_classes = [WanDecodeStep]
block_names = ["non-inpaint"]
block_trigger_inputs = [None]
@property
def description(self):
return "Decode step that decode the denoised latents into videos outputs.\n - `WanDecodeStep`"
# text2vid
class WanAutoBlocks(SequentialPipelineBlocks):
block_classes = [
WanTextEncoderStep,
WanAutoBeforeDenoiseStep,
WanAutoDenoiseStep,
WanAutoDecodeStep,
]
block_names = [
"text_encoder",
"before_denoise",
"denoise",
"decoder",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-video using Wan.\n"
+ "- for text-to-video generation, all you need to provide is `prompt`"
)
TEXT2VIDEO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("input", WanInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", WanDenoiseStep),
("decode", WanDecodeStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("before_denoise", WanAutoBeforeDenoiseStep),
("denoise", WanAutoDenoiseStep),
("decode", WanAutoDecodeStep),
]
)
ALL_BLOCKS = {
"text2video": TEXT2VIDEO_BLOCKS,
"auto": AUTO_BLOCKS,
}
@@ -1,90 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...loaders import WanLoraLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanModularPipeline(
ModularPipeline,
StableDiffusionMixin,
WanLoraLoaderMixin,
):
"""
A ModularPipeline for Wan.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
"""
@property
def default_height(self):
return self.default_sample_height * self.vae_scale_factor_spatial
@property
def default_width(self):
return self.default_sample_width * self.vae_scale_factor_spatial
@property
def default_num_frames(self):
return (self.default_sample_num_frames - 1) * self.vae_scale_factor_temporal + 1
@property
def default_sample_height(self):
return 60
@property
def default_sample_width(self):
return 104
@property
def default_sample_num_frames(self):
return 21
@property
def vae_scale_factor_spatial(self):
vae_scale_factor = 8
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
return vae_scale_factor
@property
def vae_scale_factor_temporal(self):
vae_scale_factor = 4
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** sum(self.vae.temperal_downsample)
return vae_scale_factor
@property
def num_channels_transformer(self):
num_channels_transformer = 16
if hasattr(self, "transformer") and self.transformer is not None:
num_channels_transformer = self.transformer.config.in_channels
return num_channels_transformer
@property
def num_channels_latents(self):
num_channels_latents = 16
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.z_dim
return num_channels_latents
-15
View File
@@ -380,13 +380,6 @@ else:
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
"SkyReelsV2DiffusionForcingVideoToVideoPipeline",
"SkyReelsV2ImageToVideoPipeline",
"SkyReelsV2Pipeline",
]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -858,14 +851,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SpectrogramDiffusionPipeline,
)
from .skyreels_v2 import (
SkyReelsV2DiffusionForcingImageToVideoPipeline,
SkyReelsV2DiffusionForcingPipeline,
SkyReelsV2DiffusionForcingVideoToVideoPipeline,
SkyReelsV2ImageToVideoPipeline,
SkyReelsV2Pipeline,
)
else:
import sys
+7 -8
View File
@@ -248,15 +248,14 @@ def _get_connected_pipeline(pipeline_cls):
return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)
def _get_model(pipeline_class_name):
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
for model_name, pipeline in task_mapping.items():
if pipeline.__name__ == pipeline_class_name:
return model_name
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
model_name = _get_model(pipeline_class_name)
def get_model(pipeline_class_name):
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
for model_name, pipeline in task_mapping.items():
if pipeline.__name__ == pipeline_class_name:
return model_name
model_name = get_model(pipeline_class_name)
if model_name is not None:
task_class = mapping.get(model_name, None)
@@ -663,11 +663,11 @@ class ChromaPipeline(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
a model to generate images more aligned with `prompt` at the expense of lower image quality.
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):

Some files were not shown because too many files have changed in this diff Show More