Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1662890767 | |||
| 9c0944581a | |||
| 9dc99bb069 | |||
| 4588bbeb42 | |||
| d9510862bf | |||
| 056fb8ad98 | |||
| ca913f0db4 | |||
| 769c56af6f | |||
| 1222b966d7 | |||
| 024932dd19 | |||
| ec5449f3a1 | |||
| 310fdaf556 | |||
| dcb6dd9b7a | |||
| 043ab2520f | |||
| 08c29020dd | |||
| 7a58734994 | |||
| 9ef118509e | |||
| 7c54a7b38a | |||
| 09e777a3e1 | |||
| a72bc0c4bb | |||
| 80de641c1c | |||
| 76810eca2b | |||
| 1448b03585 | |||
| 5796735015 | |||
| d8310a8fca | |||
| 78031c2938 | |||
| d83d35c1bb | |||
| 843355f89f | |||
| c006a95df1 | |||
| df267ee4e8 | |||
| edd614ea38 | |||
| 7e7e62c6ff | |||
| eda9ff8300 | |||
| efb7a299af | |||
| d06750a5fd | |||
| 8c72cd12ee | |||
| 751e250f70 | |||
| b50014067d | |||
| f5c113e439 | |||
| 5e181eddfe | |||
| 55f0b3d758 | |||
| eb7ef26736 | |||
| e1b7f1f240 | |||
| 9e7ae568d6 | |||
| f7b79452b4 | |||
| 43459079ab | |||
| 4067d6c4b6 | |||
| 28106fcac4 | |||
| c222570a9b | |||
| 4e36bb0d23 | |||
| f50b18eec7 |
@@ -110,8 +110,9 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -116,8 +116,9 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -253,9 +254,10 @@ jobs:
|
||||
python -m uv pip install -e [quality,test]
|
||||
# TODO (sayakpaul, DN6): revisit `--no-deps`
|
||||
python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -U tokenizers
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# python -m uv pip install -U tokenizers
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -132,8 +132,9 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -203,8 +204,9 @@ jobs:
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -266,7 +268,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -23,11 +23,7 @@
|
||||
- local: using-diffusers/reusing_seeds
|
||||
title: Reproducibility
|
||||
- local: using-diffusers/schedulers
|
||||
title: Load schedulers and models
|
||||
- local: using-diffusers/models
|
||||
title: Models
|
||||
- local: using-diffusers/scheduler_features
|
||||
title: Scheduler features
|
||||
title: Schedulers
|
||||
- local: using-diffusers/other-formats
|
||||
title: Model files and layouts
|
||||
- local: using-diffusers/push_to_hub
|
||||
@@ -68,10 +64,14 @@
|
||||
title: Accelerate inference
|
||||
- local: optimization/cache
|
||||
title: Caching
|
||||
- local: optimization/attention_backends
|
||||
title: Attention backends
|
||||
- local: optimization/memory
|
||||
title: Reduce memory usage
|
||||
- local: optimization/speed-memory-optims
|
||||
title: Compiling and offloading quantized models
|
||||
- local: api/parallel
|
||||
title: Parallel inference
|
||||
- title: Community optimizations
|
||||
sections:
|
||||
- local: optimization/pruna
|
||||
@@ -82,6 +82,8 @@
|
||||
title: Token merging
|
||||
- local: optimization/deepcache
|
||||
title: DeepCache
|
||||
- local: optimization/cache_dit
|
||||
title: CacheDiT
|
||||
- local: optimization/tgate
|
||||
title: TGATE
|
||||
- local: optimization/xdit
|
||||
|
||||
@@ -20,6 +20,12 @@ All pipelines with [`VaeImageProcessor`] accept PIL Image, PyTorch tensor, or Nu
|
||||
|
||||
[[autodoc]] image_processor.VaeImageProcessor
|
||||
|
||||
## InpaintProcessor
|
||||
|
||||
The [`InpaintProcessor`] accepts `mask` and `image` inputs and process them together. Optionally, it can accept padding_mask_crop and apply mask overlay.
|
||||
|
||||
[[autodoc]] image_processor.InpaintProcessor
|
||||
|
||||
## VaeImageProcessorLDM3D
|
||||
|
||||
The [`VaeImageProcessorLDM3D`] accepts RGB and depth inputs and returns RGB and depth outputs.
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Parallelism
|
||||
|
||||
Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times.
|
||||
|
||||
## ParallelConfig
|
||||
|
||||
[[autodoc]] ParallelConfig
|
||||
|
||||
## ContextParallelConfig
|
||||
|
||||
[[autodoc]] ContextParallelConfig
|
||||
|
||||
[[autodoc]] hooks.apply_context_parallel
|
||||
@@ -50,7 +50,7 @@ from diffusers.utils import export_to_video
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_backend="torchao",
|
||||
quant_kwargs={"quant_type": "int8wo"},
|
||||
components_to_quantize=["transformer"]
|
||||
components_to_quantize="transformer"
|
||||
)
|
||||
|
||||
# fp8 layerwise weight-casting
|
||||
|
||||
@@ -54,7 +54,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.bfloat16
|
||||
},
|
||||
components_to_quantize=["transformer"]
|
||||
components_to_quantize="transformer"
|
||||
)
|
||||
|
||||
pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
@@ -91,7 +91,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.bfloat16
|
||||
},
|
||||
components_to_quantize=["transformer"]
|
||||
components_to_quantize="transformer"
|
||||
)
|
||||
|
||||
pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
@@ -139,7 +139,7 @@ export_to_video(video, "output.mp4", fps=15)
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.bfloat16
|
||||
},
|
||||
components_to_quantize=["transformer"]
|
||||
components_to_quantize="transformer"
|
||||
)
|
||||
|
||||
pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
|
||||
@@ -26,6 +26,7 @@ Qwen-Image comes in the following variants:
|
||||
|:----------:|:--------:|
|
||||
| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
|
||||
| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |
|
||||
| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -96,6 +97,29 @@ The `guidance_scale` parameter in the pipeline is there to support future guidan
|
||||
|
||||
</Tip>
|
||||
|
||||
## Multi-image reference with QwenImageEditPlusPipeline
|
||||
|
||||
With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
|
||||
|
||||
```
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import QwenImageEditPlusPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = QwenImageEditPlusPipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
|
||||
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
|
||||
image = pipe(
|
||||
image=[image_1, image_2],
|
||||
prompt="put the penguin and the cat at a game show called "Qwen Edit Plus Games"",
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## QwenImagePipeline
|
||||
|
||||
[[autodoc]] QwenImagePipeline
|
||||
@@ -126,7 +150,15 @@ The `guidance_scale` parameter in the pipeline is there to support future guidan
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImaggeControlNetPipeline
|
||||
## QwenImageControlNetPipeline
|
||||
|
||||
[[autodoc]] QwenImageControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImageEditPlusPipeline
|
||||
|
||||
[[autodoc]] QwenImageEditPlusPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Attention backends
|
||||
|
||||
> [!NOTE]
|
||||
> The attention dispatcher is an experimental feature. Please open an issue if you have any feedback or encounter any problems.
|
||||
|
||||
Diffusers provides several optimized attention algorithms that are more memory and computationally efficient through it's *attention dispatcher*. The dispatcher acts as a router for managing and switching between different attention implementations and provides a unified interface for interacting with them.
|
||||
|
||||
Refer to the table below for an overview of the available attention families and to the [Available backends](#available-backends) section for a more complete list.
|
||||
|
||||
| attention family | main feature |
|
||||
|---|---|
|
||||
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
|
||||
| SageAttention | quantizes attention to int8 |
|
||||
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
|
||||
| xFormers | memory-efficient attention with support for various attention kernels |
|
||||
|
||||
This guide will show you how to set and use the different attention backends.
|
||||
|
||||
## set_attention_backend
|
||||
|
||||
The [`~ModelMixin.set_attention_backend`] method iterates through all the modules in the model and sets the appropriate attention backend to use. The attention backend setting persists until [`~ModelMixin.reset_attention_backend`] is called.
|
||||
|
||||
The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [kernel](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
|
||||
|
||||
> [!NOTE]
|
||||
> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend("flash")`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImagePipeline
|
||||
|
||||
pipeline = QwenImagePipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
pipeline.transformer.set_attention_backend("_flash_3_hub")
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
pipeline(prompt).images[0]
|
||||
```
|
||||
|
||||
To restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].
|
||||
|
||||
```py
|
||||
pipeline.transformer.reset_attention_backend()
|
||||
```
|
||||
|
||||
## attention_backend context manager
|
||||
|
||||
The [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager temporarily sets an attention backend for a model within the context. Outside the context, the default attention (PyTorch's native scaled dot product attention) is used. This is useful if you want to use different backends for different parts of a pipeline or if you want to test the different backends.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImagePipeline
|
||||
|
||||
pipeline = QwenImagePipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
|
||||
with attention_backend("_flash_3_hub"):
|
||||
image = pipeline(prompt).images[0]
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Most attention backends support `torch.compile` without graph breaks and can be used to further speed up inference.
|
||||
|
||||
## Available backends
|
||||
|
||||
Refer to the table below for a complete list of available attention backends and their variants.
|
||||
|
||||
<details>
|
||||
<summary>Expand</summary>
|
||||
|
||||
| Backend Name | Family | Description |
|
||||
|--------------|--------|-------------|
|
||||
| `native` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Default backend using PyTorch's scaled_dot_product_attention |
|
||||
| `flex` | [FlexAttention](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention) | PyTorch FlexAttention implementation |
|
||||
| `_native_cudnn` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | CuDNN-optimized attention |
|
||||
| `_native_efficient` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Memory-efficient attention |
|
||||
| `_native_flash` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | PyTorch's FlashAttention |
|
||||
| `_native_math` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Math-based attention (fallback) |
|
||||
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
|
||||
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
|
||||
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
|
||||
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
|
||||
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
|
||||
| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
|
||||
| `_sage_qk_int8_pv_fp16_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (CUDA) |
|
||||
| `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) |
|
||||
| `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention |
|
||||
|
||||
</details>
|
||||
@@ -0,0 +1,270 @@
|
||||
## CacheDiT
|
||||
|
||||
CacheDiT is a unified, flexible, and training-free cache acceleration framework designed to support nearly all Diffusers' DiT-based pipelines. It provides a unified cache API that supports automatic block adapter, DBCache, and more.
|
||||
|
||||
To learn more, refer to the [CacheDiT](https://github.com/vipshop/cache-dit) repository.
|
||||
|
||||
Install a stable release of CacheDiT from PyPI or you can install the latest version from GitHub.
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="PyPI">
|
||||
|
||||
```bash
|
||||
pip3 install -U cache-dit
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="source">
|
||||
|
||||
```bash
|
||||
pip3 install git+https://github.com/vipshop/cache-dit.git
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Run the command below to view supported DiT pipelines.
|
||||
|
||||
```python
|
||||
>>> import cache_dit
|
||||
>>> cache_dit.supported_pipelines()
|
||||
(30, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',
|
||||
'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',
|
||||
'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',
|
||||
'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*'])
|
||||
```
|
||||
|
||||
For a complete benchmark, please refer to [Benchmarks](https://github.com/vipshop/cache-dit/blob/main/bench/).
|
||||
|
||||
|
||||
## Unified Cache API
|
||||
|
||||
CacheDiT works by matching specific input/output patterns as shown below.
|
||||
|
||||

|
||||
|
||||
Call the `enable_cache()` function on a pipeline to enable cache acceleration. This function is the entry point to many of CacheDiT's features.
|
||||
|
||||
```python
|
||||
import cache_dit
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
# Can be any diffusion pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
|
||||
|
||||
# One-line code with default cache options.
|
||||
cache_dit.enable_cache(pipe)
|
||||
|
||||
# Just call the pipe as normal.
|
||||
output = pipe(...)
|
||||
|
||||
# Disable cache and run original pipe.
|
||||
cache_dit.disable_cache(pipe)
|
||||
```
|
||||
|
||||
## Automatic Block Adapter
|
||||
|
||||
For custom or modified pipelines or transformers not included in Diffusers, use the `BlockAdapter` in `auto` mode or via manual configuration. Please check the [BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#automatic-block-adapter) docs for more details. Refer to [Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.
|
||||
|
||||
|
||||
```python
|
||||
from cache_dit import ForwardPattern, BlockAdapter
|
||||
|
||||
# Use 🔥BlockAdapter with `auto` mode.
|
||||
cache_dit.enable_cache(
|
||||
BlockAdapter(
|
||||
# Any DiffusionPipeline, Qwen-Image, etc.
|
||||
pipe=pipe, auto=True,
|
||||
# Check `📚Forward Pattern Matching` documentation and hack the code of
|
||||
# of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
|
||||
forward_pattern=ForwardPattern.Pattern_1,
|
||||
),
|
||||
)
|
||||
|
||||
# Or, manually setup transformer configurations.
|
||||
cache_dit.enable_cache(
|
||||
BlockAdapter(
|
||||
pipe=pipe, # Qwen-Image, etc.
|
||||
transformer=pipe.transformer,
|
||||
blocks=pipe.transformer.transformer_blocks,
|
||||
forward_pattern=ForwardPattern.Pattern_1,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
Sometimes, a Transformer class will contain more than one transformer `blocks`. For example, FLUX.1 (HiDream, Chroma, etc) contains `transformer_blocks` and `single_transformer_blocks` (with different forward patterns). The BlockAdapter is able to detect this hybrid pattern type as well.
|
||||
Refer to [FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.
|
||||
|
||||
```python
|
||||
# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
|
||||
# single_transformer_blocks have different forward patterns.
|
||||
cache_dit.enable_cache(
|
||||
BlockAdapter(
|
||||
pipe=pipe, # FLUX.1, etc.
|
||||
transformer=pipe.transformer,
|
||||
blocks=[
|
||||
pipe.transformer.transformer_blocks,
|
||||
pipe.transformer.single_transformer_blocks,
|
||||
],
|
||||
forward_pattern=[
|
||||
ForwardPattern.Pattern_1,
|
||||
ForwardPattern.Pattern_3,
|
||||
],
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
This also works if there is more than one transformer (namely `transformer` and `transformer_2`) in its structure. Refer to [Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.
|
||||
|
||||
## Patch Functor
|
||||
|
||||
For any pattern not included in CacheDiT, use the Patch Functor to convert the pattern into a known pattern. You need to subclass the Patch Functor and may also need to fuse the operations within the blocks for loop into block `forward`. After implementing a Patch Functor, set the `patch_functor` property in `BlockAdapter`.
|
||||
|
||||

|
||||
|
||||
Some Patch Functors are already provided in CacheDiT, [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py), [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py), etc.
|
||||
|
||||
```python
|
||||
@BlockAdapterRegistry.register("HiDream")
|
||||
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
|
||||
from diffusers import HiDreamImageTransformer2DModel
|
||||
from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
|
||||
|
||||
assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
|
||||
return BlockAdapter(
|
||||
pipe=pipe,
|
||||
transformer=pipe.transformer,
|
||||
blocks=[
|
||||
pipe.transformer.double_stream_blocks,
|
||||
pipe.transformer.single_stream_blocks,
|
||||
],
|
||||
forward_pattern=[
|
||||
ForwardPattern.Pattern_0,
|
||||
ForwardPattern.Pattern_3,
|
||||
],
|
||||
# NOTE: Setup your custom patch functor here.
|
||||
patch_functor=HiDreamPatchFunctor(),
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
Finally, you can call the `cache_dit.summary()` function on a pipeline after its completed inference to get the cache acceleration details.
|
||||
|
||||
```python
|
||||
stats = cache_dit.summary(pipe)
|
||||
```
|
||||
|
||||
```python
|
||||
⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline
|
||||
|
||||
| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
|
||||
|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
|
||||
| 23 | 0.045 | 0.084 | 0.114 | 0.147 | 0.241 | 0.297 |
|
||||
```
|
||||
|
||||
## DBCache: Dual Block Cache
|
||||
|
||||

|
||||
|
||||
DBCache (Dual Block Caching) supports different configurations of compute blocks (F8B12, etc.) to enable a balanced trade-off between performance and precision.
|
||||
- Fn_compute_blocks: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
|
||||
- Bn_compute_blocks: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
|
||||
|
||||
|
||||
```python
|
||||
import cache_dit
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe_or_adapter = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
# Default options, F8B0, 8 warmup steps, and unlimited cached
|
||||
# steps for good balance between performance and precision
|
||||
cache_dit.enable_cache(pipe_or_adapter)
|
||||
|
||||
# Custom options, F8B8, higher precision
|
||||
from cache_dit import BasicCacheConfig
|
||||
|
||||
cache_dit.enable_cache(
|
||||
pipe_or_adapter,
|
||||
cache_config=BasicCacheConfig(
|
||||
max_warmup_steps=8, # steps do not cache
|
||||
max_cached_steps=-1, # -1 means no limit
|
||||
Fn_compute_blocks=8, # Fn, F8, etc.
|
||||
Bn_compute_blocks=8, # Bn, B8, etc.
|
||||
residual_diff_threshold=0.12,
|
||||
),
|
||||
)
|
||||
```
|
||||
Check the [DBCache](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) and [User Guide](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#dbcache) docs for more design details.
|
||||
|
||||
## TaylorSeer Calibrator
|
||||
|
||||
The [TaylorSeers](https://huggingface.co/papers/2503.06923) algorithm further improves the precision of DBCache in cases where the cached steps are large (Hybrid TaylorSeer + DBCache). At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
|
||||
|
||||
TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. F_pred can be a residual cache or a hidden-state cache.
|
||||
|
||||
```python
|
||||
from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig
|
||||
|
||||
cache_dit.enable_cache(
|
||||
pipe_or_adapter,
|
||||
# Basic DBCache w/ FnBn configurations
|
||||
cache_config=BasicCacheConfig(
|
||||
max_warmup_steps=8, # steps do not cache
|
||||
max_cached_steps=-1, # -1 means no limit
|
||||
Fn_compute_blocks=8, # Fn, F8, etc.
|
||||
Bn_compute_blocks=8, # Bn, B8, etc.
|
||||
residual_diff_threshold=0.12,
|
||||
),
|
||||
# Then, you can use the TaylorSeer Calibrator to approximate
|
||||
# the values in cached steps, taylorseer_order default is 1.
|
||||
calibrator_config=TaylorSeerCalibratorConfig(
|
||||
taylorseer_order=1,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> The `Bn_compute_blocks` parameter of DBCache can be set to `0` if you use TaylorSeer as the calibrator for approximate hidden states. DBCache's `Bn_compute_blocks` also acts as a calibrator, so you can choose either `Bn_compute_blocks` > 0 or TaylorSeer. We recommend using the configuration scheme of TaylorSeer + DBCache FnB0.
|
||||
|
||||
## Hybrid Cache CFG
|
||||
|
||||
CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG in the forward step, please set `enable_separate_cfg` parameter to `False (default, None)`. Otherwise, set it to `True`.
|
||||
|
||||
```python
|
||||
from cache_dit import BasicCacheConfig
|
||||
|
||||
cache_dit.enable_cache(
|
||||
pipe_or_adapter,
|
||||
cache_config=BasicCacheConfig(
|
||||
...,
|
||||
# For example, set it as True for Wan 2.1, Qwen-Image
|
||||
# and set it as False for FLUX.1, HunyuanVideo, etc.
|
||||
enable_separate_cfg=True,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
## torch.compile
|
||||
|
||||
CacheDiT is designed to work with torch.compile for even better performance. Call `torch.compile` after enabling the cache.
|
||||
|
||||
|
||||
```python
|
||||
cache_dit.enable_cache(pipe)
|
||||
|
||||
# Compile the Transformer module
|
||||
pipe.transformer = torch.compile(pipe.transformer)
|
||||
```
|
||||
|
||||
If you're using CacheDiT with dynamic input shapes, consider increasing the `recompile_limit` of `torch._dynamo`. Otherwise, the `recompile_limit` error may be triggered, causing the module to fall back to eager mode.
|
||||
|
||||
```python
|
||||
torch._dynamo.config.recompile_limit = 96 # default is 8
|
||||
torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
|
||||
```
|
||||
|
||||
Please check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.
|
||||
@@ -291,13 +291,53 @@ Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://
|
||||
> [!WARNING]
|
||||
> Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading's device casting mechanism.
|
||||
|
||||
Call [`~ModelMixin.enable_group_offload`] to enable it for standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
|
||||
|
||||
The `offload_type` parameter can be set to `block_level` or `leaf_level`.
|
||||
Enable group offloading by configuring the `offload_type` parameter to `block_level` or `leaf_level`.
|
||||
|
||||
- `block_level` offloads groups of layers based on the `num_blocks_per_group` parameter. For example, if `num_blocks_per_group=2` on a model with 40 layers, 2 layers are onloaded and offloaded at a time (20 total onloads/offloads). This drastically reduces memory requirements.
|
||||
- `leaf_level` offloads individual layers at the lowest level and is equivalent to [CPU offloading](#cpu-offloading). But it can be made faster if you use streams without giving up inference speed.
|
||||
|
||||
Group offloading is supported for entire pipelines or individual models. Applying group offloading to the entire pipeline is the easiest option while selectively applying it to individual models gives users more flexibility to use different offloading techniques for different models.
|
||||
|
||||
<hfoptions id="group-offloading">
|
||||
<hfoption id="pipeline">
|
||||
|
||||
Call [`~DiffusionPipeline.enable_group_offload`] on a pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
onload_device = torch.device("cuda")
|
||||
offload_device = torch.device("cpu")
|
||||
|
||||
pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipeline.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="leaf_level",
|
||||
use_stream=True
|
||||
)
|
||||
|
||||
prompt = (
|
||||
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
||||
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
||||
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
||||
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
||||
"atmosphere of this unique musical performance."
|
||||
)
|
||||
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="model">
|
||||
|
||||
Call [`~ModelMixin.enable_group_offload`] on standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
@@ -328,6 +368,9 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
#### CUDA stream
|
||||
|
||||
The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.
|
||||
|
||||
@@ -34,7 +34,9 @@ Initialize [`~quantizers.PipelineQuantizationConfig`] with the following paramet
|
||||
> [!TIP]
|
||||
> These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend.
|
||||
|
||||
- `components_to_quantize` specifies which components of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
|
||||
- `components_to_quantize` specifies which component(s) of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
|
||||
|
||||
`components_to_quantize` accepts either a list for multiple models or a string for a single model.
|
||||
|
||||
The example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`.
|
||||
|
||||
@@ -62,6 +64,7 @@ pipe = DiffusionPipeline.from_pretrained(
|
||||
image = pipe("photo of a cute dog").images[0]
|
||||
```
|
||||
|
||||
|
||||
### Advanced quantization
|
||||
|
||||
The `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends.
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Models
|
||||
|
||||
A diffusion model relies on a few individual models working together to generate an output. These models are responsible for denoising, encoding inputs, and decoding latents into the actual outputs.
|
||||
|
||||
This guide will show you how to load models.
|
||||
|
||||
## Loading a model
|
||||
|
||||
All models are loaded with the [`~ModelMixin.from_pretrained`] method, which downloads and caches the latest model version. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache.
|
||||
|
||||
Pass the `subfolder` argument to [`~ModelMixin.from_pretrained`] to specify where to load the model weights from. Omit the `subfolder` argument if the repository doesn't have a subfolder structure or if you're loading a standalone model.
|
||||
|
||||
```py
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer")
|
||||
```
|
||||
|
||||
## AutoModel
|
||||
|
||||
[`AutoModel`] detects the model class from a `model_index.json` file or a model's `config.json` file. It fetches the correct model class from these files and delegates the actual loading to the model class. [`AutoModel`] is useful for automatic model type detection without needing to know the exact model class beforehand.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"Qwen/Qwen-Image", subfolder="transformer"
|
||||
)
|
||||
```
|
||||
|
||||
## Model data types
|
||||
|
||||
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to load a model with a specific data type. This allows you to load a model in a lower precision to reduce memory usage.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained(
|
||||
"Qwen/Qwen-Image",
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
[nn.Module.to](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to) can also convert to a specific data type on the fly. However, it converts *all* weights to the requested data type unlike `torch_dtype` which respects `_keep_in_fp32_modules`. This argument preserves layers in `torch.float32` for numerical stability and best generation quality (see example [_keep_in_fp32_modules](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374))
|
||||
|
||||
```py
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained(
|
||||
"Qwen/Qwen-Image", subfolder="transformer"
|
||||
)
|
||||
model = model.to(dtype=torch.float16)
|
||||
```
|
||||
|
||||
## Device placement
|
||||
|
||||
Use the `device_map` argument in [`~ModelMixin.from_pretrained`] to place a model on an accelerator like a GPU. It is especially helpful where there are multiple GPUs.
|
||||
|
||||
Diffusers currently provides three options to `device_map` for individual models, `"cuda"`, `"balanced"` and `"auto"`. Refer to the table below to compare the three placement strategies.
|
||||
|
||||
| parameter | description |
|
||||
|---|---|
|
||||
| `"cuda"` | places pipeline on a supported accelerator (CUDA) |
|
||||
| `"balanced"` | evenly distributes pipeline on all GPUs |
|
||||
| `"auto"` | distribute model from fastest device first to slowest |
|
||||
|
||||
Use the `max_memory` argument in [`~ModelMixin.from_pretrained`] to allocate a maximum amount of memory to use on each device. By default, Diffusers uses the maximum amount available.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImagePipeline
|
||||
|
||||
max_memory = {0: "16GB", 1: "16GB"}
|
||||
pipeline = QwenImagePipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda",
|
||||
max_memory=max_memory
|
||||
)
|
||||
```
|
||||
|
||||
The `hf_device_map` attribute allows you to access and view the `device_map`.
|
||||
|
||||
```py
|
||||
print(transformer.hf_device_map)
|
||||
# {'': device(type='cuda')}
|
||||
```
|
||||
|
||||
## Saving models
|
||||
|
||||
Save a model with the [`~ModelMixin.save_pretrained`] method.
|
||||
|
||||
```py
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer")
|
||||
model.save_pretrained("./local/model")
|
||||
```
|
||||
|
||||
For large models, it is helpful to use `max_shard_size` to save a model as multiple shards. A shard can be loaded faster and save memory (refer to the [parallel loading](./loading#parallel-loading) docs for more details), especially if there is more than one GPU.
|
||||
|
||||
```py
|
||||
model.save_pretrained("./local/model", max_shard_size="5GB")
|
||||
```
|
||||
@@ -1,235 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Scheduler features
|
||||
|
||||
The scheduler is an important component of any diffusion model because it controls the entire denoising (or sampling) process. There are many types of schedulers, some are optimized for speed and some for quality. With Diffusers, you can modify the scheduler configuration to use custom noise schedules, sigmas, and rescale the noise schedule. Changing these parameters can have profound effects on inference quality and speed.
|
||||
|
||||
This guide will demonstrate how to use these features to improve inference quality.
|
||||
|
||||
> [!TIP]
|
||||
> Diffusers currently only supports the `timesteps` and `sigmas` parameters for a select list of schedulers and pipelines. Feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
|
||||
|
||||
## Timestep schedules
|
||||
|
||||
The timestep or noise schedule determines the amount of noise at each sampling step. The scheduler uses this to generate an image with the corresponding amount of noise at each step. The timestep schedule is generated from the scheduler's default configuration, but you can customize the scheduler to use new and optimized sampling schedules that aren't in Diffusers yet.
|
||||
|
||||
For example, [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) is a method for optimizing a sampling schedule to generate a high-quality image in as little as 10 steps. The optimal [10-step schedule](https://github.com/huggingface/diffusers/blob/a7bf77fc284810483f1e60afe34d1d27ad91ce2e/src/diffusers/schedulers/scheduling_utils.py#L51) for Stable Diffusion XL is:
|
||||
|
||||
```py
|
||||
from diffusers.schedulers import AysSchedules
|
||||
|
||||
sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
|
||||
print(sampling_schedule)
|
||||
"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
|
||||
```
|
||||
|
||||
You can use the AYS sampling schedule in a pipeline by passing it to the `timesteps` parameter.
|
||||
|
||||
```py
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++")
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
generator = torch.Generator(device="cpu").manual_seed(2487854446)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
generator=generator,
|
||||
timesteps=sampling_schedule,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Timestep spacing
|
||||
|
||||
The way sample steps are selected in the schedule can affect the quality of the generated image, especially with respect to [rescaling the noise schedule](#rescale-noise-schedule), which can enable a model to generate much brighter or darker images. Diffusers provides three timestep spacing methods:
|
||||
|
||||
- `leading` creates evenly spaced steps
|
||||
- `linspace` includes the first and last steps and evenly selects the remaining intermediate steps
|
||||
- `trailing` only includes the last step and evenly selects the remaining intermediate steps starting from the end
|
||||
|
||||
It is recommended to use the `trailing` spacing method because it generates higher quality images with more details when there are fewer sample steps. But the difference in quality is not as obvious for more standard sample step values.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
|
||||
|
||||
prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
|
||||
generator = torch.Generator(device="cpu").manual_seed(2487854446)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
generator=generator,
|
||||
num_inference_steps=5,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/trailing_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">trailing spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/leading_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">leading spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Sigmas
|
||||
|
||||
The `sigmas` parameter is the amount of noise added at each timestep according to the timestep schedule. Like the `timesteps` parameter, you can customize the `sigmas` parameter to control how much noise is added at each step. When you use a custom `sigmas` value, the `timesteps` are calculated from the custom `sigmas` value and the default scheduler configuration is ignored.
|
||||
|
||||
For example, you can manually pass the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) for something like the 10-step AYS schedule from before to the pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
|
||||
prompt = "anthropomorphic capybara wearing a suit and working with a computer"
|
||||
generator = torch.Generator(device='cuda').manual_seed(123)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=10,
|
||||
sigmas=sigmas,
|
||||
generator=generator
|
||||
).images[0]
|
||||
```
|
||||
|
||||
When you take a look at the scheduler's `timesteps` parameter, you'll see that it is the same as the AYS timestep schedule because the `timestep` schedule is calculated from the `sigmas`.
|
||||
|
||||
```py
|
||||
print(f" timesteps: {pipe.scheduler.timesteps}")
|
||||
"timesteps: tensor([999., 845., 730., 587., 443., 310., 193., 116., 53., 13.], device='cuda:0')"
|
||||
```
|
||||
|
||||
### Karras sigmas
|
||||
|
||||
> [!TIP]
|
||||
> Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas.
|
||||
>
|
||||
> Karras sigmas should not be used for models that weren't trained with them. For example, the base Stable Diffusion XL model shouldn't use Karras sigmas but the [DreamShaperXL](https://hf.co/Lykon/dreamshaper-xl-1-0) model can since they are trained with Karras sigmas.
|
||||
|
||||
Karras scheduler's use the timestep schedule and sigmas from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://hf.co/papers/2206.00364) paper. This scheduler variant applies a smaller amount of noise per step as it approaches the end of the sampling process compared to other schedulers, and can increase the level of details in the generated image.
|
||||
|
||||
Enable Karras sigmas by setting `use_karras_sigmas=True` in the scheduler.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
generator = torch.Generator(device="cpu").manual_seed(2487854446)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
generator=generator,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_true.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas enabled</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_false.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas disabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Rescale noise schedule
|
||||
|
||||
In the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://hf.co/papers/2305.08891) paper, the authors discovered that common noise schedules allowed some signal to leak into the last timestep. This signal leakage at inference can cause models to only generate images with medium brightness. By enforcing a zero signal-to-noise ratio (SNR) for the timstep schedule and sampling from the last timestep, the model can be improved to generate very bright or dark images.
|
||||
|
||||
> [!TIP]
|
||||
> For inference, you need a model that has been trained with *v_prediction*. To train your own model with *v_prediction*, add the following flag to the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) scripts.
|
||||
>
|
||||
> ```bash
|
||||
> --prediction_type="v_prediction"
|
||||
> ```
|
||||
|
||||
For example, load the [ptx0/pseudo-journey-v2](https://hf.co/ptx0/pseudo-journey-v2) checkpoint which was trained with `v_prediction` and the [`DDIMScheduler`]. Configure the following parameters in the [`DDIMScheduler`]:
|
||||
|
||||
* `rescale_betas_zero_snr=True` to rescale the noise schedule to zero SNR
|
||||
* `timestep_spacing="trailing"` to start sampling from the last timestep
|
||||
|
||||
Set `guidance_rescale` in the pipeline to prevent over-exposure. A lower value increases brightness but some of the details may appear washed out.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, DDIMScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", use_safetensors=True)
|
||||
|
||||
pipeline.scheduler = DDIMScheduler.from_config(
|
||||
pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
prompt = "cinematic photo of a snowy mountain at night with the northern lights aurora borealis overhead, 35mm photograph, film, professional, 4k, highly detailed"
|
||||
generator = torch.Generator(device="cpu").manual_seed(23)
|
||||
image = pipeline(prompt, guidance_rescale=0.7, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/no-zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">default Stable Diffusion v2-1 image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">image with zero SNR and trailing timestep spacing enabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
@@ -10,200 +10,273 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Load schedulers and models
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
Diffusion pipelines are a collection of interchangeable schedulers and models that can be mixed and matched to tailor a pipeline to a specific use case. The scheduler encapsulates the entire denoising process such as the number of denoising steps and the algorithm for finding the denoised sample. A scheduler is not parameterized or trained so they don't take very much memory. The model is usually only concerned with the forward pass of going from a noisy input to a less noisy sample.
|
||||
# Schedulers
|
||||
|
||||
This guide will show you how to load schedulers and models to customize a pipeline. You'll use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint throughout this guide, so let's load it first.
|
||||
A scheduler is an algorithm that provides instructions to the denoising process such as how much noise to remove at a certain step. It takes the model prediction from step *t* and applies an update for how to compute the next sample at step *t-1*. Different schedulers produce different results; some are faster while others are more accurate.
|
||||
|
||||
Diffusers supports many schedulers and allows you to modify their timestep schedules, timestep spacing, and more, to generate high-quality images in fewer steps.
|
||||
|
||||
This guide will show you how to load and customize schedulers.
|
||||
|
||||
## Loading schedulers
|
||||
|
||||
Schedulers don't have any parameters and are defined in a configuration file. Access the `.scheduler` attribute of a pipeline to view the configuration.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
You can see what scheduler this pipeline uses with the `pipeline.scheduler` attribute.
|
||||
|
||||
```py
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler
|
||||
PNDMScheduler {
|
||||
"_class_name": "PNDMScheduler",
|
||||
"_diffusers_version": "0.21.4",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": false,
|
||||
"num_train_timesteps": 1000,
|
||||
"set_alpha_to_one": false,
|
||||
"skip_prk_steps": true,
|
||||
"steps_offset": 1,
|
||||
"timestep_spacing": "leading",
|
||||
"trained_betas": null
|
||||
}
|
||||
```
|
||||
|
||||
## Load a scheduler
|
||||
|
||||
Schedulers are defined by a configuration file that can be used by a variety of schedulers. Load a scheduler with the [`SchedulerMixin.from_pretrained`] method, and specify the `subfolder` parameter to load the configuration file into the correct subfolder of the pipeline repository.
|
||||
|
||||
For example, to load the [`DDIMScheduler`]:
|
||||
|
||||
```py
|
||||
from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
|
||||
ddim = DDIMScheduler.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler")
|
||||
```
|
||||
|
||||
Then you can pass the newly loaded scheduler to the pipeline.
|
||||
|
||||
```python
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=ddim, torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
## Compare schedulers
|
||||
|
||||
Schedulers have their own unique strengths and weaknesses, making it difficult to quantitatively compare which scheduler works best for a pipeline. You typically have to make a trade-off between denoising speed and denoising quality. We recommend trying out different schedulers to find one that works best for your use case. Call the `pipeline.scheduler.compatibles` attribute to see what schedulers are compatible with a pipeline.
|
||||
|
||||
Let's compare the [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], and the [`DPMSolverMultistepScheduler`] on the following prompt and seed.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
|
||||
generator = torch.Generator(device="cuda").manual_seed(8)
|
||||
```
|
||||
|
||||
To change the pipelines scheduler, use the [`~ConfigMixin.from_config`] method to load a different scheduler's `pipeline.scheduler.config` into the pipeline.
|
||||
|
||||
<hfoptions id="schedulers">
|
||||
<hfoption id="LMSDiscreteScheduler">
|
||||
|
||||
[`LMSDiscreteScheduler`] typically generates higher quality images than the default scheduler.
|
||||
|
||||
```py
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="EulerDiscreteScheduler">
|
||||
|
||||
[`EulerDiscreteScheduler`] can generate higher quality images in just 30 steps.
|
||||
|
||||
```py
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="EulerAncestralDiscreteScheduler">
|
||||
|
||||
[`EulerAncestralDiscreteScheduler`] can generate higher quality images in just 30 steps.
|
||||
|
||||
```py
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPMSolverMultistepScheduler">
|
||||
|
||||
[`DPMSolverMultistepScheduler`] provides a balance between speed and quality and can generate higher quality images in just 20 steps.
|
||||
Load a different scheduler with [`~SchedulerMixin.from_pretrained`] and specify the `subfolder` argument to load the configuration file into the correct subfolder of the pipeline repository. Pass the new scheduler to the existing pipeline.
|
||||
|
||||
```py
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
dpm = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
scheduler=dpm,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler
|
||||
```
|
||||
|
||||
## Timestep schedules
|
||||
|
||||
Timestep or noise schedule decides how noise is distributed over the denoising process. The schedule can be linear or more concentrated toward the beginning or end. It is a precomputed sequence of noise levels generated from the scheduler's default configuration, but it can be customized to use other schedules.
|
||||
|
||||
> [!TIP]
|
||||
> The `timesteps` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
|
||||
|
||||
The example below uses the [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) schedule which can generate a high-quality image in 10 steps, significantly speeding up generation and reducing computation time.
|
||||
|
||||
Import the schedule and pass it to the `timesteps` argument in the pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
from diffusers.schedulers import AysSchedules
|
||||
|
||||
sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
|
||||
print(sampling_schedule)
|
||||
"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
|
||||
)
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
timesteps=sampling_schedule,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Rescaling schedules
|
||||
|
||||
Denoising should begin with pure noise and the signal-to-noise (SNR) ration should be zero. However, some models don't actually start from pure noise which makes it difficult to generate images at brightness extremes.
|
||||
|
||||
> [!TIP]
|
||||
> Train your own model with `v_prediction` by adding the `--prediction_type="v_prediction"` flag to your training script. You can also [search](https://huggingface.co/search/full-text?q=v_prediction&type=model) for existing models trained with `v_prediction`.
|
||||
|
||||
To fix this, a model must be trained with `v_prediction`. If a model is trained with `v_prediction`, then enable the following arguments in the scheduler.
|
||||
|
||||
- Set `rescale_betas_zero_snr=True` to rescale the noise schedule to the very last timestep with exactly zero SNR
|
||||
- Set `timestep_spacing="trailing"` to force sampling from the last timestep with pure noise
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, DDIMScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", device_map="cuda")
|
||||
|
||||
pipeline.scheduler = DDIMScheduler.from_config(
|
||||
pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
||||
)
|
||||
```
|
||||
|
||||
Set `guidance_rescale` in the pipeline to avoid overexposed images. A lower value increases brightness, but some details may appear washed out.
|
||||
|
||||
```py
|
||||
prompt = """
|
||||
cinematic photo of a snowy mountain at night with the northern lights aurora borealis
|
||||
overhead, 35mm photograph, film, professional, 4k, highly detailed
|
||||
"""
|
||||
image = pipeline(prompt, guidance_rescale=0.7).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/no-zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">default Stable Diffusion v2-1 image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">image with zero SNR and trailing timestep spacing enabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Timestep spacing
|
||||
|
||||
Timestep spacing refers to the specific steps *t* to sample from from the schedule. Diffusers provides three spacing types as shown below.
|
||||
|
||||
| spacing strategy | spacing calculation | example timesteps |
|
||||
|---|---|---|
|
||||
| `leading` | evenly spaced steps | `[900, 800, 700, ..., 100, 0]` |
|
||||
| `linspace` | include first and last steps and evenly divide remaining intermediate steps | `[1000, 888.89, 777.78, ..., 111.11, 0]` |
|
||||
| `trailing` | include last step and evenly divide remaining intermediate steps beginning from the end | `[999, 899, 799, 699, 599, 499, 399, 299, 199, 99]` |
|
||||
|
||||
Pass the spacing strategy to the `timestep_spacing` argument in the scheduler.
|
||||
|
||||
> [!TIP]
|
||||
> The `trailing` strategy typically produces higher quality images with more details with fewer steps, but the difference in quality is not as obvious for more standard step values.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, timestep_spacing="trailing"
|
||||
)
|
||||
|
||||
prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
num_inference_steps=5,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">LMSDiscreteScheduler</figcaption>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/trailing_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">trailing spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">EulerDiscreteScheduler</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">EulerAncestralDiscreteScheduler</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">DPMSolverMultistepScheduler</figcaption>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/leading_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">leading spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Most images look very similar and are comparable in quality. Again, it often comes down to your specific use case so a good approach is to run multiple different schedulers and compare the results.
|
||||
## Sigmas
|
||||
|
||||
## Models
|
||||
Sigmas is a measure of how noisy a sample is at a certain step as defined by the schedule. When using custom `sigmas`, the `timesteps` are calculated from these values instead of the default scheduler configuration.
|
||||
|
||||
Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them.
|
||||
> [!TIP]
|
||||
> The `sigmas` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
|
||||
|
||||
Models can be loaded from a subfolder with the `subfolder` argument. For example, the model weights for [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) are stored in the [unet](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet) subfolder.
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", use_safetensors=True)
|
||||
```
|
||||
|
||||
They can also be directly loaded from a [repository](https://huggingface.co/google/ddpm-cifar10-32/tree/main).
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DModel
|
||||
|
||||
unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
|
||||
```
|
||||
|
||||
To load and save model variants, specify the `variant` argument in [`ModelMixin.from_pretrained`] and [`ModelMixin.save_pretrained`].
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
|
||||
)
|
||||
unet.save_pretrained("./local-unet", variant="non_ema")
|
||||
```
|
||||
|
||||
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
|
||||
Pass the custom sigmas to the `sigmas` argument in the pipeline. The example below uses the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) from the 10-step AYS schedule.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
unet = AutoModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
|
||||
)
|
||||
|
||||
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
sigmas=sigmas,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights unlike the `torch_dtype` argument that respects the `_keep_in_fp32_modules`. This is important for models whose layers must remain in fp32 for numerical stability and best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
|
||||
### Karras sigmas
|
||||
|
||||
[Karras sigmas](https://huggingface.co/papers/2206.00364) resamples the noise schedule for more efficient sampling by clustering sigmas more densely in the middle of the sequence where structure reconstruction is critical, while using fewer sigmas at the beginning and end where noise changes have less impact. This can increase the level of details in a generated image.
|
||||
|
||||
Set `use_karras_sigmas=True` in the scheduler to enable it.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config,
|
||||
algorithm_type="sde-dpmsolver++",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
sigmas=sigmas,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_true.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas enabled</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_false.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas disabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas. It should only be used for models trained with Karras sigmas.
|
||||
|
||||
## Choosing a scheduler
|
||||
|
||||
It's important to try different schedulers to find the best one for your use case. Here are a few recommendations to help you get started.
|
||||
|
||||
- DPM++ 2M SDE Karras is generally a good all-purpose option.
|
||||
- [`TCDScheduler`] works well for distilled models.
|
||||
- [`FlowMatchEulerDiscreteScheduler`] and [`FlowMatchHeunDiscreteScheduler`] for FlowMatch models.
|
||||
- [`EulerDiscreteScheduler`] or [`EulerAncestralDiscreteScheduler`] for generating anime style images.
|
||||
- DPM++ 2M paired with [`LCMScheduler`] on SDXL for generating realistic images.
|
||||
|
||||
## Resources
|
||||
|
||||
- Read the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more details about rescaling the noise schedule to enforce zero SNR.
|
||||
@@ -98,7 +98,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.bfloat16
|
||||
},
|
||||
components_to_quantize=["transformer"]
|
||||
components_to_quantize="transformer"
|
||||
)
|
||||
|
||||
pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
|
||||
@@ -1705,6 +1705,12 @@ class FaithDiffStableDiffusionXLPipeline(
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
self.unet.denoise_encoder.enable_tiling()
|
||||
|
||||
@@ -1713,6 +1719,12 @@ class FaithDiffStableDiffusionXLPipeline(
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
self.unet.denoise_encoder.disable_tiling()
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -643,6 +644,12 @@ class FluxKontextPipeline(
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
|
||||
@@ -651,6 +658,12 @@ class FluxKontextPipeline(
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def preprocess_image(self, image: PipelineImageInput, _auto_resize: bool, multiple_of: int) -> torch.Tensor:
|
||||
|
||||
@@ -30,6 +30,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -526,6 +527,12 @@ class RFInversionFluxPipeline(
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
||||
deprecate(
|
||||
"enable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
@@ -533,6 +540,12 @@ class RFInversionFluxPipeline(
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
||||
deprecate(
|
||||
"disable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
@@ -541,6 +554,12 @@ class RFInversionFluxPipeline(
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
@@ -548,6 +567,12 @@ class RFInversionFluxPipeline(
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_latents_inversion(
|
||||
|
||||
@@ -35,6 +35,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -702,6 +703,12 @@ class FluxSemanticGuidancePipeline(
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
|
||||
@@ -710,6 +717,12 @@ class FluxSemanticGuidancePipeline(
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
|
||||
|
||||
@@ -28,6 +28,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -503,6 +504,12 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
||||
deprecate(
|
||||
"enable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
@@ -510,6 +517,12 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
||||
deprecate(
|
||||
"disable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
@@ -518,6 +531,12 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
@@ -525,6 +544,12 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_latents(
|
||||
|
||||
@@ -29,11 +29,7 @@ from diffusers.models.transformers import SD3Transformer2DModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
|
||||
@@ -504,6 +504,12 @@ class StableDiffusionBoxDiffPipeline(
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
||||
deprecate(
|
||||
"enable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
@@ -511,6 +517,12 @@ class StableDiffusionBoxDiffPipeline(
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
||||
deprecate(
|
||||
"disable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
@@ -519,6 +531,12 @@ class StableDiffusionBoxDiffPipeline(
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
@@ -526,6 +544,12 @@ class StableDiffusionBoxDiffPipeline(
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def _encode_prompt(
|
||||
|
||||
@@ -471,6 +471,12 @@ class StableDiffusionPAGPipeline(
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
||||
deprecate(
|
||||
"enable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
@@ -478,6 +484,12 @@ class StableDiffusionPAGPipeline(
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
||||
deprecate(
|
||||
"disable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
@@ -486,6 +498,12 @@ class StableDiffusionPAGPipeline(
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
@@ -493,6 +511,12 @@ class StableDiffusionPAGPipeline(
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def _encode_prompt(
|
||||
|
||||
@@ -26,7 +26,7 @@ from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3
|
||||
from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
@@ -481,6 +481,12 @@ class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
||||
deprecate(
|
||||
"enable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
@@ -488,6 +494,12 @@ class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
||||
deprecate(
|
||||
"disable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
@@ -496,6 +508,12 @@ class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
@@ -503,6 +521,12 @@ class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
@property
|
||||
|
||||
@@ -26,11 +26,7 @@ from diffusers.models import AutoencoderKLMochi, MochiTransformer3DModel
|
||||
from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
@@ -458,6 +454,12 @@ class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
||||
deprecate(
|
||||
"enable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
@@ -465,6 +467,12 @@ class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
||||
deprecate(
|
||||
"disable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
@@ -473,6 +481,12 @@ class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
@@ -480,6 +494,12 @@ class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_latents(
|
||||
|
||||
@@ -29,8 +29,9 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.state import AcceleratorState
|
||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from huggingface_hub.utils import insecure_hashlib
|
||||
@@ -1222,6 +1223,9 @@ def main(args):
|
||||
kwargs_handlers=[kwargs],
|
||||
)
|
||||
|
||||
if accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
@@ -1438,17 +1442,20 @@ def main(args):
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
modules_to_save = {}
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
model = unwrap_model(model)
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["transformer"] = model
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
|
||||
model = unwrap_model(model)
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["text_encoder"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
if weights:
|
||||
weights.pop()
|
||||
|
||||
FluxKontextPipeline.save_lora_weights(
|
||||
output_dir,
|
||||
@@ -1461,15 +1468,25 @@ def main(args):
|
||||
transformer_ = None
|
||||
text_encoder_one_ = None
|
||||
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
transformer_ = unwrap_model(model)
|
||||
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_ = unwrap_model(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
else:
|
||||
transformer_ = FluxTransformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="transformer"
|
||||
)
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
text_encoder_one_ = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder"
|
||||
)
|
||||
|
||||
lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
|
||||
|
||||
@@ -2069,7 +2086,7 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
|
||||
@@ -263,6 +263,12 @@ class PromptDiffusionPipeline(
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
@@ -271,6 +277,12 @@ class PromptDiffusionPipeline(
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToImageInput(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
size: str | None = None
|
||||
n: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PresetModels:
|
||||
SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
|
||||
SD3_5: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
"stabilityai/stable-diffusion-3.5-large-turbo",
|
||||
"stabilityai/stable-diffusion-3.5-medium",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TextToImagePipelineSD3:
|
||||
def __init__(self, model_path: str | None = None):
|
||||
self.model_path = model_path or os.getenv("MODEL_PATH")
|
||||
self.pipeline: StableDiffusion3Pipeline | None = None
|
||||
self.device: str | None = None
|
||||
|
||||
def start(self):
|
||||
if torch.cuda.is_available():
|
||||
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
|
||||
logger.info("Loading CUDA")
|
||||
self.device = "cuda"
|
||||
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.float16,
|
||||
).to(device=self.device)
|
||||
elif torch.backends.mps.is_available():
|
||||
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
|
||||
logger.info("Loading MPS for Mac M Series")
|
||||
self.device = "mps"
|
||||
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to(device=self.device)
|
||||
else:
|
||||
raise Exception("No CUDA or MPS device available")
|
||||
|
||||
|
||||
class ModelPipelineInitializer:
|
||||
def __init__(self, model: str = "", type_models: str = "t2im"):
|
||||
self.model = model
|
||||
self.type_models = type_models
|
||||
self.pipeline = None
|
||||
self.device = "cuda" if torch.cuda.is_available() else "mps"
|
||||
self.model_type = None
|
||||
|
||||
def initialize_pipeline(self):
|
||||
if not self.model:
|
||||
raise ValueError("Model name not provided")
|
||||
|
||||
# Check if model exists in PresetModels
|
||||
preset_models = PresetModels()
|
||||
|
||||
# Determine which model type we're dealing with
|
||||
if self.model in preset_models.SD3:
|
||||
self.model_type = "SD3"
|
||||
elif self.model in preset_models.SD3_5:
|
||||
self.model_type = "SD3_5"
|
||||
|
||||
# Create appropriate pipeline based on model type and type_models
|
||||
if self.type_models == "t2im":
|
||||
if self.model_type in ["SD3", "SD3_5"]:
|
||||
self.pipeline = TextToImagePipelineSD3(self.model)
|
||||
else:
|
||||
raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
|
||||
elif self.type_models == "t2v":
|
||||
raise ValueError(f"Unsupported type_models: {self.type_models}")
|
||||
|
||||
return self.pipeline
|
||||
@@ -0,0 +1,171 @@
|
||||
# Asynchronous server and parallel execution of models
|
||||
|
||||
> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.
|
||||
> We recommend running 10 to 50 inferences in parallel for optimal performance, averaging between 25 and 30 seconds to 1 minute and 1 minute and 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two inferences in parallel to avoid decoding or saving errors due to memory shortages.)
|
||||
|
||||
## ⚠️ IMPORTANT
|
||||
|
||||
* The example demonstrates how to run pipelines like `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on GPU.
|
||||
|
||||
## Necessary components
|
||||
|
||||
All the components needed to create the inference server are in the current directory:
|
||||
|
||||
```
|
||||
server-async/
|
||||
├── utils/
|
||||
├─────── __init__.py
|
||||
├─────── scheduler.py # BaseAsyncScheduler wrapper and async_retrieve_timesteps for secure inferences
|
||||
├─────── requestscopedpipeline.py # RequestScoped Pipeline for inference with a single in-memory model
|
||||
├─────── utils.py # Image/video saving utilities and service configuration
|
||||
├── Pipelines.py # pipeline loader classes (SD3)
|
||||
├── serverasync.py # FastAPI app with lifespan management and async inference endpoints
|
||||
├── test.py # Client test script for inference requests
|
||||
├── requirements.txt # Dependencies
|
||||
└── README.md # This documentation
|
||||
```
|
||||
|
||||
## What `diffusers-async` adds / Why we needed it
|
||||
|
||||
Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline per-request.
|
||||
|
||||
`diffusers-async` / this example addresses that by:
|
||||
|
||||
* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.
|
||||
* **Per-request mutable state**: stateful small objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with fallback to safe `deepcopy` or other heuristics.
|
||||
* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.
|
||||
* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully retro-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not using `return_scheduler=True`, the behavior is identical to the original API.
|
||||
* **Robust attribute handling**: wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. Configurable tensor size threshold prevents cloning of large tensors.
|
||||
* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).
|
||||
|
||||
## How the server works (high-level flow)
|
||||
|
||||
1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.
|
||||
2. On each HTTP inference request:
|
||||
|
||||
* The server uses `RequestScopedPipeline.generate(...)` which:
|
||||
|
||||
* automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped),
|
||||
* obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),
|
||||
* does `local_pipe = copy.copy(base_pipe)` (shallow copy),
|
||||
* sets `local_pipe.scheduler = local_scheduler` (if possible),
|
||||
* clones only small mutable attributes (callbacks, rng, small latents) with auto-detection,
|
||||
* wraps tokenizers with thread-safe locks to prevent race conditions,
|
||||
* optionally enters a `model_cpu_offload_context()` for memory offload hooks,
|
||||
* calls the pipeline on the local view (`local_pipe(...)`).
|
||||
3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`).
|
||||
4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state.
|
||||
|
||||
## How to set up and run the server
|
||||
|
||||
### 1) Install dependencies
|
||||
|
||||
Recommended: create a virtualenv / conda environment.
|
||||
|
||||
```bash
|
||||
pip install diffusers
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2) Start the server
|
||||
|
||||
Using the `serverasync.py` file that already has everything you need:
|
||||
|
||||
```bash
|
||||
python serverasync.py
|
||||
```
|
||||
|
||||
The server will start on `http://localhost:8500` by default with the following features:
|
||||
- FastAPI application with async lifespan management
|
||||
- Automatic model loading and pipeline initialization
|
||||
- Request counting and active inference tracking
|
||||
- Memory cleanup after each inference
|
||||
- CORS middleware for cross-origin requests
|
||||
|
||||
### 3) Test the server
|
||||
|
||||
Use the included test script:
|
||||
|
||||
```bash
|
||||
python test.py
|
||||
```
|
||||
|
||||
Or send a manual request:
|
||||
|
||||
`POST /api/diffusers/inference` with JSON body:
|
||||
|
||||
```json
|
||||
{
|
||||
"prompt": "A futuristic cityscape, vibrant colors",
|
||||
"num_inference_steps": 30,
|
||||
"num_images_per_prompt": 1
|
||||
}
|
||||
```
|
||||
|
||||
Response example:
|
||||
|
||||
```json
|
||||
{
|
||||
"response": ["http://localhost:8500/images/img123.png"]
|
||||
}
|
||||
```
|
||||
|
||||
### 4) Server endpoints
|
||||
|
||||
- `GET /` - Welcome message
|
||||
- `POST /api/diffusers/inference` - Main inference endpoint
|
||||
- `GET /images/{filename}` - Serve generated images
|
||||
- `GET /api/status` - Server status and memory info
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### RequestScopedPipeline Parameters
|
||||
|
||||
```python
|
||||
RequestScopedPipeline(
|
||||
pipeline, # Base pipeline to wrap
|
||||
mutable_attrs=None, # Custom list of attributes to clone
|
||||
auto_detect_mutables=True, # Enable automatic detection of mutable attributes
|
||||
tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning
|
||||
tokenizer_lock=None, # Custom threading lock for tokenizers
|
||||
wrap_scheduler=True # Auto-wrap scheduler in BaseAsyncScheduler
|
||||
)
|
||||
```
|
||||
|
||||
### BaseAsyncScheduler Features
|
||||
|
||||
* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`
|
||||
* `clone_for_request()` method for safe per-request scheduler cloning
|
||||
* Enhanced debugging with `__repr__` and `__str__` methods
|
||||
* Full compatibility with existing scheduler APIs
|
||||
|
||||
### Server Configuration
|
||||
|
||||
The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ServerConfigModels:
|
||||
model: str = 'stabilityai/stable-diffusion-3.5-medium'
|
||||
type_models: str = 't2im'
|
||||
host: str = '0.0.0.0'
|
||||
port: int = 8500
|
||||
```
|
||||
|
||||
## Troubleshooting (quick)
|
||||
|
||||
* `Already borrowed` — previously a Rust tokenizer concurrency error.
|
||||
✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen.
|
||||
|
||||
* `can't set attribute 'components'` — pipeline exposes read-only `components`.
|
||||
✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically.
|
||||
|
||||
* Scheduler issues:
|
||||
* If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log and fallback — but prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.
|
||||
✅ Note: `async_retrieve_timesteps` is fully retro-compatible — if you don't pass `return_scheduler=True`, the behavior is unchanged.
|
||||
|
||||
* Memory issues with large tensors:
|
||||
✅ The system now has configurable `tensor_numel_threshold` to prevent cloning of large tensors while still cloning small mutable ones.
|
||||
|
||||
* Automatic tokenizer detection:
|
||||
✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers.
|
||||
@@ -0,0 +1,10 @@
|
||||
torch
|
||||
torchvision
|
||||
transformers
|
||||
sentencepiece
|
||||
fastapi
|
||||
uvicorn
|
||||
ftfy
|
||||
accelerate
|
||||
xformers
|
||||
protobuf
|
||||
@@ -0,0 +1,230 @@
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from Pipelines import ModelPipelineInitializer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from utils import RequestScopedPipeline, Utils
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerConfigModels:
|
||||
model: str = "stabilityai/stable-diffusion-3.5-medium"
|
||||
type_models: str = "t2im"
|
||||
constructor_pipeline: Optional[Type] = None
|
||||
custom_pipeline: Optional[Type] = None
|
||||
components: Optional[Dict[str, Any]] = None
|
||||
torch_dtype: Optional[torch.dtype] = None
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8500
|
||||
|
||||
|
||||
server_config = ServerConfigModels()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
app.state.logger = logging.getLogger("diffusers-server")
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
|
||||
|
||||
app.state.total_requests = 0
|
||||
app.state.active_inferences = 0
|
||||
app.state.metrics_lock = asyncio.Lock()
|
||||
app.state.metrics_task = None
|
||||
|
||||
app.state.utils_app = Utils(
|
||||
host=server_config.host,
|
||||
port=server_config.port,
|
||||
)
|
||||
|
||||
async def metrics_loop():
|
||||
try:
|
||||
while True:
|
||||
async with app.state.metrics_lock:
|
||||
total = app.state.total_requests
|
||||
active = app.state.active_inferences
|
||||
app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
app.state.logger.info("Metrics loop cancelled")
|
||||
raise
|
||||
|
||||
app.state.metrics_task = asyncio.create_task(metrics_loop())
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
task = app.state.metrics_task
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
|
||||
if callable(stop_fn):
|
||||
await run_in_threadpool(stop_fn)
|
||||
except Exception as e:
|
||||
app.state.logger.warning(f"Error during pipeline shutdown: {e}")
|
||||
|
||||
app.state.logger.info("Lifespan shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
logger = logging.getLogger("DiffusersServer.Pipelines")
|
||||
|
||||
|
||||
initializer = ModelPipelineInitializer(
|
||||
model=server_config.model,
|
||||
type_models=server_config.type_models,
|
||||
)
|
||||
model_pipeline = initializer.initialize_pipeline()
|
||||
model_pipeline.start()
|
||||
|
||||
request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
|
||||
pipeline_lock = threading.Lock()
|
||||
|
||||
logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})")
|
||||
|
||||
app.state.MODEL_INITIALIZER = initializer
|
||||
app.state.MODEL_PIPELINE = model_pipeline
|
||||
app.state.REQUEST_PIPE = request_pipe
|
||||
app.state.PIPELINE_LOCK = pipeline_lock
|
||||
|
||||
|
||||
class JSONBodyQueryAPI(BaseModel):
|
||||
model: str | None = None
|
||||
prompt: str
|
||||
negative_prompt: str | None = None
|
||||
num_inference_steps: int = 28
|
||||
num_images_per_prompt: int = 1
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def count_requests_middleware(request: Request, call_next):
|
||||
async with app.state.metrics_lock:
|
||||
app.state.total_requests += 1
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Welcome to the Diffusers Server"}
|
||||
|
||||
|
||||
@app.post("/api/diffusers/inference")
|
||||
async def api(json: JSONBodyQueryAPI):
|
||||
prompt = json.prompt
|
||||
negative_prompt = json.negative_prompt or ""
|
||||
num_steps = json.num_inference_steps
|
||||
num_images_per_prompt = json.num_images_per_prompt
|
||||
|
||||
wrapper = app.state.MODEL_PIPELINE
|
||||
initializer = app.state.MODEL_INITIALIZER
|
||||
|
||||
utils_app = app.state.utils_app
|
||||
|
||||
if not wrapper or not wrapper.pipeline:
|
||||
raise HTTPException(500, "Model not initialized correctly")
|
||||
if not prompt.strip():
|
||||
raise HTTPException(400, "No prompt provided")
|
||||
|
||||
def make_generator():
|
||||
g = torch.Generator(device=initializer.device)
|
||||
return g.manual_seed(random.randint(0, 10_000_000))
|
||||
|
||||
req_pipe = app.state.REQUEST_PIPE
|
||||
|
||||
def infer():
|
||||
gen = make_generator()
|
||||
return req_pipe.generate(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=gen,
|
||||
num_inference_steps=num_steps,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=initializer.device,
|
||||
output_type="pil",
|
||||
)
|
||||
|
||||
try:
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences += 1
|
||||
|
||||
output = await run_in_threadpool(infer)
|
||||
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences = max(0, app.state.active_inferences - 1)
|
||||
|
||||
urls = [utils_app.save_image(img) for img in output.images]
|
||||
return {"response": urls}
|
||||
|
||||
except Exception as e:
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences = max(0, app.state.active_inferences - 1)
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise HTTPException(500, f"Error in processing: {e}")
|
||||
|
||||
finally:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.ipc_collect()
|
||||
gc.collect()
|
||||
|
||||
|
||||
@app.get("/images/{filename}")
|
||||
async def serve_image(filename: str):
|
||||
utils_app = app.state.utils_app
|
||||
file_path = os.path.join(utils_app.image_dir, filename)
|
||||
if not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
return FileResponse(file_path, media_type="image/png")
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status():
|
||||
memory_info = {}
|
||||
if torch.cuda.is_available():
|
||||
memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
|
||||
memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB
|
||||
memory_info = {
|
||||
"memory_allocated_gb": round(memory_allocated, 2),
|
||||
"memory_reserved_gb": round(memory_reserved, 2),
|
||||
"device": torch.cuda.get_device_name(0),
|
||||
}
|
||||
|
||||
return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host=server_config.host, port=server_config.port)
|
||||
@@ -0,0 +1,65 @@
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
SERVER_URL = "http://localhost:8500/api/diffusers/inference"
|
||||
BASE_URL = "http://localhost:8500"
|
||||
DOWNLOAD_FOLDER = "generated_images"
|
||||
WAIT_BEFORE_DOWNLOAD = 2 # seconds
|
||||
|
||||
os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)
|
||||
|
||||
|
||||
def save_from_url(url: str) -> str:
|
||||
"""Download the given URL (relative or absolute) and save it locally."""
|
||||
if url.startswith("/"):
|
||||
direct = BASE_URL.rstrip("/") + url
|
||||
else:
|
||||
direct = url
|
||||
resp = requests.get(direct, timeout=60)
|
||||
resp.raise_for_status()
|
||||
filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png"
|
||||
path = os.path.join(DOWNLOAD_FOLDER, filename)
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
return path
|
||||
|
||||
|
||||
def main():
|
||||
payload = {
|
||||
"prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
|
||||
"num_inference_steps": 30,
|
||||
"num_images_per_prompt": 1,
|
||||
}
|
||||
|
||||
print("Sending request...")
|
||||
try:
|
||||
r = requests.post(SERVER_URL, json=payload, timeout=480)
|
||||
r.raise_for_status()
|
||||
except Exception as e:
|
||||
print(f"Request failed: {e}")
|
||||
return
|
||||
|
||||
body = r.json().get("response", [])
|
||||
# Normalize to a list
|
||||
urls = body if isinstance(body, list) else [body] if body else []
|
||||
if not urls:
|
||||
print("No URLs found in the response. Check the server output.")
|
||||
return
|
||||
|
||||
print(f"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...")
|
||||
time.sleep(WAIT_BEFORE_DOWNLOAD)
|
||||
|
||||
for u in urls:
|
||||
try:
|
||||
path = save_from_url(u)
|
||||
print(f"Image saved to: {path}")
|
||||
except Exception as e:
|
||||
print(f"Error downloading {u}: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,2 @@
|
||||
from .requestscopedpipeline import RequestScopedPipeline
|
||||
from .utils import Utils
|
||||
@@ -0,0 +1,296 @@
|
||||
import copy
|
||||
import threading
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def safe_tokenize(tokenizer, *args, lock, **kwargs):
|
||||
with lock:
|
||||
return tokenizer(*args, **kwargs)
|
||||
|
||||
|
||||
class RequestScopedPipeline:
|
||||
DEFAULT_MUTABLE_ATTRS = [
|
||||
"_all_hooks",
|
||||
"_offload_device",
|
||||
"_progress_bar_config",
|
||||
"_progress_bar",
|
||||
"_rng_state",
|
||||
"_last_seed",
|
||||
"latents",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline: Any,
|
||||
mutable_attrs: Optional[Iterable[str]] = None,
|
||||
auto_detect_mutables: bool = True,
|
||||
tensor_numel_threshold: int = 1_000_000,
|
||||
tokenizer_lock: Optional[threading.Lock] = None,
|
||||
wrap_scheduler: bool = True,
|
||||
):
|
||||
self._base = pipeline
|
||||
self.unet = getattr(pipeline, "unet", None)
|
||||
self.vae = getattr(pipeline, "vae", None)
|
||||
self.text_encoder = getattr(pipeline, "text_encoder", None)
|
||||
self.components = getattr(pipeline, "components", None)
|
||||
|
||||
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
|
||||
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
|
||||
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
|
||||
|
||||
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
|
||||
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
|
||||
|
||||
self._auto_detect_mutables = bool(auto_detect_mutables)
|
||||
self._tensor_numel_threshold = int(tensor_numel_threshold)
|
||||
|
||||
self._auto_detected_attrs: List[str] = []
|
||||
|
||||
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
|
||||
base_sched = getattr(self._base, "scheduler", None)
|
||||
if base_sched is None:
|
||||
return None
|
||||
|
||||
if not isinstance(base_sched, BaseAsyncScheduler):
|
||||
wrapped_scheduler = BaseAsyncScheduler(base_sched)
|
||||
else:
|
||||
wrapped_scheduler = base_sched
|
||||
|
||||
try:
|
||||
return wrapped_scheduler.clone_for_request(
|
||||
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
|
||||
try:
|
||||
return copy.deepcopy(wrapped_scheduler)
|
||||
except Exception as e:
|
||||
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
|
||||
return wrapped_scheduler
|
||||
|
||||
def _autodetect_mutables(self, max_attrs: int = 40):
|
||||
if not self._auto_detect_mutables:
|
||||
return []
|
||||
|
||||
if self._auto_detected_attrs:
|
||||
return self._auto_detected_attrs
|
||||
|
||||
candidates: List[str] = []
|
||||
seen = set()
|
||||
for name in dir(self._base):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
if name in self._mutable_attrs:
|
||||
continue
|
||||
if name in ("to", "save_pretrained", "from_pretrained"):
|
||||
continue
|
||||
try:
|
||||
val = getattr(self._base, name)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
import types
|
||||
|
||||
# skip callables and modules
|
||||
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
|
||||
continue
|
||||
|
||||
# containers -> candidate
|
||||
if isinstance(val, (dict, list, set, tuple, bytearray)):
|
||||
candidates.append(name)
|
||||
seen.add(name)
|
||||
else:
|
||||
# try Tensor detection
|
||||
try:
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.numel() <= self._tensor_numel_threshold:
|
||||
candidates.append(name)
|
||||
seen.add(name)
|
||||
else:
|
||||
logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if len(candidates) >= max_attrs:
|
||||
break
|
||||
|
||||
self._auto_detected_attrs = candidates
|
||||
logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
|
||||
return self._auto_detected_attrs
|
||||
|
||||
def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
|
||||
try:
|
||||
cls = type(base_obj)
|
||||
descriptor = getattr(cls, attr_name, None)
|
||||
if isinstance(descriptor, property):
|
||||
return descriptor.fset is None
|
||||
if hasattr(descriptor, "__set__") is False and descriptor is not None:
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _clone_mutable_attrs(self, base, local):
|
||||
attrs_to_clone = list(self._mutable_attrs)
|
||||
attrs_to_clone.extend(self._autodetect_mutables())
|
||||
|
||||
EXCLUDE_ATTRS = {
|
||||
"components",
|
||||
}
|
||||
|
||||
for attr in attrs_to_clone:
|
||||
if attr in EXCLUDE_ATTRS:
|
||||
logger.debug(f"Skipping excluded attr '{attr}'")
|
||||
continue
|
||||
if not hasattr(base, attr):
|
||||
continue
|
||||
if self._is_readonly_property(base, attr):
|
||||
logger.debug(f"Skipping read-only property '{attr}'")
|
||||
continue
|
||||
|
||||
try:
|
||||
val = getattr(base, attr)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
setattr(local, attr, dict(val))
|
||||
elif isinstance(val, (list, tuple, set)):
|
||||
setattr(local, attr, list(val))
|
||||
elif isinstance(val, bytearray):
|
||||
setattr(local, attr, bytearray(val))
|
||||
else:
|
||||
# small tensors or atomic values
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.numel() <= self._tensor_numel_threshold:
|
||||
setattr(local, attr, val.clone())
|
||||
else:
|
||||
# don't clone big tensors, keep reference
|
||||
setattr(local, attr, val)
|
||||
else:
|
||||
try:
|
||||
setattr(local, attr, copy.copy(val))
|
||||
except Exception:
|
||||
setattr(local, attr, val)
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
|
||||
continue
|
||||
|
||||
def _is_tokenizer_component(self, component) -> bool:
|
||||
if component is None:
|
||||
return False
|
||||
|
||||
tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
|
||||
has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
|
||||
|
||||
class_name = component.__class__.__name__.lower()
|
||||
has_tokenizer_in_name = "tokenizer" in class_name
|
||||
|
||||
tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
|
||||
has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
|
||||
|
||||
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
|
||||
|
||||
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
|
||||
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
|
||||
|
||||
try:
|
||||
local_pipe = copy.copy(self._base)
|
||||
except Exception as e:
|
||||
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
|
||||
local_pipe = copy.deepcopy(self._base)
|
||||
|
||||
if local_scheduler is not None:
|
||||
try:
|
||||
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
|
||||
local_scheduler.scheduler,
|
||||
num_inference_steps=num_inference_steps,
|
||||
device=device,
|
||||
return_scheduler=True,
|
||||
**{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
|
||||
)
|
||||
|
||||
final_scheduler = BaseAsyncScheduler(configured_scheduler)
|
||||
setattr(local_pipe, "scheduler", final_scheduler)
|
||||
except Exception:
|
||||
logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
|
||||
|
||||
self._clone_mutable_attrs(self._base, local_pipe)
|
||||
|
||||
# 4) wrap tokenizers on the local pipe with the lock wrapper
|
||||
tokenizer_wrappers = {} # name -> original_tokenizer
|
||||
try:
|
||||
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
|
||||
for name in dir(local_pipe):
|
||||
if "tokenizer" in name and not name.startswith("_"):
|
||||
tok = getattr(local_pipe, name, None)
|
||||
if tok is not None and self._is_tokenizer_component(tok):
|
||||
tokenizer_wrappers[name] = tok
|
||||
setattr(
|
||||
local_pipe,
|
||||
name,
|
||||
lambda *args, tok=tok, **kwargs: safe_tokenize(
|
||||
tok, *args, lock=self._tokenizer_lock, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
# b) wrap tokenizers in components dict
|
||||
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
|
||||
for key, val in local_pipe.components.items():
|
||||
if val is None:
|
||||
continue
|
||||
|
||||
if self._is_tokenizer_component(val):
|
||||
tokenizer_wrappers[f"components[{key}]"] = val
|
||||
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
|
||||
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
|
||||
|
||||
result = None
|
||||
cm = getattr(local_pipe, "model_cpu_offload_context", None)
|
||||
try:
|
||||
if callable(cm):
|
||||
try:
|
||||
with cm():
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
except TypeError:
|
||||
# cm might be a context manager instance rather than callable
|
||||
try:
|
||||
with cm:
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
except Exception as e:
|
||||
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
else:
|
||||
# no offload context available — call directly
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
try:
|
||||
for name, tok in tokenizer_wrappers.items():
|
||||
if name.startswith("components["):
|
||||
key = name[len("components[") : -1]
|
||||
local_pipe.components[key] = tok
|
||||
else:
|
||||
setattr(local_pipe, name, tok)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error restoring wrapped tokenizers: {e}")
|
||||
@@ -0,0 +1,141 @@
|
||||
import copy
|
||||
import inspect
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseAsyncScheduler:
|
||||
def __init__(self, scheduler: Any):
|
||||
self.scheduler = scheduler
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if hasattr(self.scheduler, name):
|
||||
return getattr(self.scheduler, name)
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def __setattr__(self, name: str, value):
|
||||
if name == "scheduler":
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
|
||||
setattr(self.scheduler, name, value)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
|
||||
local = copy.deepcopy(self.scheduler)
|
||||
local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
|
||||
cloned = self.__class__(local)
|
||||
return cloned
|
||||
|
||||
def __repr__(self):
|
||||
return f"BaseAsyncScheduler({repr(self.scheduler)})"
|
||||
|
||||
def __str__(self):
|
||||
return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
|
||||
|
||||
|
||||
def async_retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
|
||||
Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Backwards compatible: by default the function behaves exactly as before and returns
|
||||
(timesteps_tensor, num_inference_steps)
|
||||
|
||||
If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
|
||||
scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
|
||||
or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:
|
||||
(timesteps_tensor, num_inference_steps, scheduler_in_use)
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Optional kwargs:
|
||||
return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)
|
||||
where `scheduler_in_use` is a scheduler instance that already has timesteps set.
|
||||
This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.
|
||||
|
||||
Returns:
|
||||
`(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
|
||||
`(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
|
||||
"""
|
||||
# pop our optional control kwarg (keeps compatibility)
|
||||
return_scheduler = bool(kwargs.pop("return_scheduler", False))
|
||||
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
|
||||
# choose scheduler to call set_timesteps on
|
||||
scheduler_in_use = scheduler
|
||||
if return_scheduler:
|
||||
# Do not mutate the provided scheduler: prefer to clone if possible
|
||||
if hasattr(scheduler, "clone_for_request"):
|
||||
try:
|
||||
# clone_for_request may accept num_inference_steps or other kwargs; be permissive
|
||||
scheduler_in_use = scheduler.clone_for_request(
|
||||
num_inference_steps=num_inference_steps or 0, device=device
|
||||
)
|
||||
except Exception:
|
||||
scheduler_in_use = copy.deepcopy(scheduler)
|
||||
else:
|
||||
# fallback deepcopy (scheduler tends to be smallish - acceptable)
|
||||
scheduler_in_use = copy.deepcopy(scheduler)
|
||||
|
||||
# helper to test if set_timesteps supports a particular kwarg
|
||||
def _accepts(param_name: str) -> bool:
|
||||
try:
|
||||
return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())
|
||||
except (ValueError, TypeError):
|
||||
# if signature introspection fails, be permissive and attempt the call later
|
||||
return False
|
||||
|
||||
# now call set_timesteps on the chosen scheduler_in_use (may be original or clone)
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = _accepts("timesteps")
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
num_inference_steps = len(timesteps_out)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = _accepts("sigmas")
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
num_inference_steps = len(timesteps_out)
|
||||
else:
|
||||
# default path
|
||||
scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
|
||||
if return_scheduler:
|
||||
return timesteps_out, num_inference_steps, scheduler_in_use
|
||||
return timesteps_out, num_inference_steps
|
||||
@@ -0,0 +1,48 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Utils:
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8500):
|
||||
self.service_url = f"http://{host}:{port}"
|
||||
self.image_dir = os.path.join(tempfile.gettempdir(), "images")
|
||||
if not os.path.exists(self.image_dir):
|
||||
os.makedirs(self.image_dir)
|
||||
|
||||
self.video_dir = os.path.join(tempfile.gettempdir(), "videos")
|
||||
if not os.path.exists(self.video_dir):
|
||||
os.makedirs(self.video_dir)
|
||||
|
||||
def save_image(self, image):
|
||||
if hasattr(image, "to"):
|
||||
try:
|
||||
image = image.to("cpu")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
from torchvision import transforms
|
||||
|
||||
to_pil = transforms.ToPILImage()
|
||||
image = to_pil(image.squeeze(0).clamp(0, 1))
|
||||
|
||||
filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png"
|
||||
image_path = os.path.join(self.image_dir, filename)
|
||||
logger.info(f"Saving image to {image_path}")
|
||||
|
||||
image.save(image_path, format="PNG", optimize=True)
|
||||
|
||||
del image
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return os.path.join(self.service_url, "images", filename)
|
||||
@@ -9,8 +9,8 @@ This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server
|
||||
Start by navigating to the `examples/server` folder and installing all of the dependencies.
|
||||
|
||||
```py
|
||||
pip install .
|
||||
pip install -f requirements.txt
|
||||
pip install diffusers
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Launch the server with the following command.
|
||||
|
||||
@@ -6,4 +6,5 @@ py-consul
|
||||
prometheus_client >= 0.18.0
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
fastapi
|
||||
uvicorn
|
||||
uvicorn
|
||||
accelerate
|
||||
|
||||
@@ -39,7 +39,7 @@ fsspec==2024.10.0
|
||||
# torch
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
huggingface-hub==0.26.1
|
||||
huggingface-hub==0.35.0
|
||||
# via
|
||||
# tokenizers
|
||||
# transformers
|
||||
|
||||
@@ -278,6 +278,29 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-VACE-Fun-14B":
|
||||
config = {
|
||||
"model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
|
||||
"vace_in_channels": 96,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-I2V-14B-720p":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.2-I2V-A14B",
|
||||
@@ -975,7 +998,17 @@ if __name__ == "__main__":
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
elif "VACE" in args.model_type:
|
||||
elif "Wan2.2-VACE" in args.model_type:
|
||||
pipe = WanVACEPipeline(
|
||||
transformer=transformer,
|
||||
transformer_2=transformer_2,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
boundary_ratio=0.875,
|
||||
)
|
||||
elif "Wan-VACE" in args.model_type:
|
||||
pipe = WanVACEPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
|
||||
@@ -102,7 +102,8 @@ _deps = [
|
||||
"filelock",
|
||||
"flax>=0.4.1",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub>=0.34.0",
|
||||
"httpx<1.0.0",
|
||||
"huggingface-hub>=0.34.0,<2.0",
|
||||
"requests-mock==1.10.0",
|
||||
"importlib_metadata",
|
||||
"invisible-watermark>=0.2.0",
|
||||
@@ -259,6 +260,7 @@ extras["dev"] = (
|
||||
install_requires = [
|
||||
deps["importlib_metadata"],
|
||||
deps["filelock"],
|
||||
deps["httpx"],
|
||||
deps["huggingface-hub"],
|
||||
deps["numpy"],
|
||||
deps["regex"],
|
||||
|
||||
@@ -202,6 +202,7 @@ else:
|
||||
"CogView4Transformer2DModel",
|
||||
"ConsisIDTransformer3DModel",
|
||||
"ConsistencyDecoderVAE",
|
||||
"ContextParallelConfig",
|
||||
"ControlNetModel",
|
||||
"ControlNetUnionModel",
|
||||
"ControlNetXSAdapter",
|
||||
@@ -229,6 +230,7 @@ else:
|
||||
"MultiAdapter",
|
||||
"MultiControlNetModel",
|
||||
"OmniGenTransformer2DModel",
|
||||
"ParallelConfig",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"QwenImageControlNetModel",
|
||||
@@ -385,6 +387,10 @@ else:
|
||||
[
|
||||
"FluxAutoBlocks",
|
||||
"FluxModularPipeline",
|
||||
"QwenImageAutoBlocks",
|
||||
"QwenImageEditAutoBlocks",
|
||||
"QwenImageEditModularPipeline",
|
||||
"QwenImageModularPipeline",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
"WanAutoBlocks",
|
||||
@@ -491,6 +497,7 @@ else:
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
"LucyEditPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
"LuminaPipeline",
|
||||
@@ -506,9 +513,11 @@ else:
|
||||
"PixArtAlphaPipeline",
|
||||
"PixArtSigmaPAGPipeline",
|
||||
"PixArtSigmaPipeline",
|
||||
"QwenImageControlNetInpaintPipeline",
|
||||
"QwenImageControlNetPipeline",
|
||||
"QwenImageEditInpaintPipeline",
|
||||
"QwenImageEditPipeline",
|
||||
"QwenImageEditPlusPipeline",
|
||||
"QwenImageImg2ImgPipeline",
|
||||
"QwenImageInpaintPipeline",
|
||||
"QwenImagePipeline",
|
||||
@@ -881,6 +890,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4Transformer2DModel,
|
||||
ConsisIDTransformer3DModel,
|
||||
ConsistencyDecoderVAE,
|
||||
ContextParallelConfig,
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
@@ -908,6 +918,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MultiAdapter,
|
||||
MultiControlNetModel,
|
||||
OmniGenTransformer2DModel,
|
||||
ParallelConfig,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
QwenImageControlNetModel,
|
||||
@@ -1038,6 +1049,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .modular_pipelines import (
|
||||
FluxAutoBlocks,
|
||||
FluxModularPipeline,
|
||||
QwenImageAutoBlocks,
|
||||
QwenImageEditAutoBlocks,
|
||||
QwenImageEditModularPipeline,
|
||||
QwenImageModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
WanAutoBlocks,
|
||||
@@ -1140,6 +1155,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
LucyEditPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
LuminaPipeline,
|
||||
@@ -1155,9 +1171,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PixArtAlphaPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
PixArtSigmaPipeline,
|
||||
QwenImageControlNetInpaintPipeline,
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
QwenImageEditPipeline,
|
||||
QwenImageEditPlusPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImagePipeline,
|
||||
|
||||
@@ -30,11 +30,11 @@ import numpy as np
|
||||
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
from requests import HTTPError
|
||||
from typing_extensions import Self
|
||||
|
||||
from . import __version__
|
||||
@@ -419,7 +419,7 @@ class ConfigMixin:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
except HfHubHTTPError as err:
|
||||
raise EnvironmentError(
|
||||
"There was a specific connection error when trying to load"
|
||||
f" {pretrained_model_name_or_path}:\n{err}"
|
||||
|
||||
@@ -9,7 +9,8 @@ deps = {
|
||||
"filelock": "filelock",
|
||||
"flax": "flax>=0.4.1",
|
||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.34.0",
|
||||
"httpx": "httpx<1.0.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.34.0,<2.0",
|
||||
"requests-mock": "requests-mock==1.10.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"invisible-watermark": "invisible-watermark>=0.2.0",
|
||||
|
||||
@@ -16,6 +16,7 @@ from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .context_parallel import apply_context_parallel
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
|
||||
@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
|
||||
from ..models.attention_processor import AttnProcessor2_0
|
||||
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
|
||||
from ..models.transformers.transformer_flux import FluxAttnProcessor
|
||||
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
|
||||
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
|
||||
|
||||
# AttnProcessor2_0
|
||||
@@ -140,6 +141,14 @@ def _register_attention_processors_metadata():
|
||||
metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
|
||||
)
|
||||
|
||||
# QwenDoubleStreamAttnProcessor2
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=QwenDoubleStreamAttnProcessor2_0,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
@@ -298,4 +307,5 @@ _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___h
|
||||
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||
# not sure what this is yet.
|
||||
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||
# fmt: on
|
||||
|
||||
@@ -0,0 +1,297 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
|
||||
from ..models._modeling_parallel import (
|
||||
ContextParallelConfig,
|
||||
ContextParallelInput,
|
||||
ContextParallelModelPlan,
|
||||
ContextParallelOutput,
|
||||
)
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
|
||||
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
|
||||
|
||||
|
||||
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
|
||||
@dataclass
|
||||
class ModuleForwardMetadata:
|
||||
cached_parameter_indices: Dict[str, int] = None
|
||||
_cls: Type = None
|
||||
|
||||
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
if identifier in kwargs:
|
||||
return kwargs[identifier], True, None
|
||||
|
||||
if self.cached_parameter_indices is not None:
|
||||
index = self.cached_parameter_indices.get(identifier, None)
|
||||
if index is None:
|
||||
raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
|
||||
return args[index], False, index
|
||||
|
||||
if self._cls is None:
|
||||
raise ValueError("Model class is not set for metadata.")
|
||||
|
||||
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
|
||||
parameters = parameters[1:] # skip `self`
|
||||
self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
|
||||
|
||||
if identifier not in self.cached_parameter_indices:
|
||||
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
|
||||
|
||||
index = self.cached_parameter_indices[identifier]
|
||||
|
||||
if index >= len(args):
|
||||
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
|
||||
|
||||
return args[index], False, index
|
||||
|
||||
|
||||
def apply_context_parallel(
|
||||
module: torch.nn.Module,
|
||||
parallel_config: ContextParallelConfig,
|
||||
plan: Dict[str, ContextParallelModelPlan],
|
||||
) -> None:
|
||||
"""Apply context parallel on a model."""
|
||||
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
|
||||
|
||||
for module_id, cp_model_plan in plan.items():
|
||||
submodule = _get_submodule_by_name(module, module_id)
|
||||
if not isinstance(submodule, list):
|
||||
submodule = [submodule]
|
||||
|
||||
logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
|
||||
|
||||
for m in submodule:
|
||||
if isinstance(cp_model_plan, dict):
|
||||
hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
|
||||
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
|
||||
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
|
||||
if isinstance(cp_model_plan, ContextParallelOutput):
|
||||
cp_model_plan = [cp_model_plan]
|
||||
if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
|
||||
raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
|
||||
hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
|
||||
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
|
||||
else:
|
||||
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(m)
|
||||
registry.register_hook(hook, hook_name)
|
||||
|
||||
|
||||
def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
|
||||
for module_id, cp_model_plan in plan.items():
|
||||
submodule = _get_submodule_by_name(module, module_id)
|
||||
if not isinstance(submodule, list):
|
||||
submodule = [submodule]
|
||||
|
||||
for m in submodule:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(m)
|
||||
if isinstance(cp_model_plan, dict):
|
||||
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
|
||||
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
|
||||
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
|
||||
else:
|
||||
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
|
||||
registry.remove_hook(hook_name)
|
||||
|
||||
|
||||
class ContextParallelSplitHook(ModelHook):
|
||||
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
|
||||
super().__init__()
|
||||
self.metadata = metadata
|
||||
self.parallel_config = parallel_config
|
||||
self.module_forward_metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
cls = unwrap_module(module).__class__
|
||||
self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
|
||||
return module
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
args_list = list(args)
|
||||
|
||||
for name, cpm in self.metadata.items():
|
||||
if isinstance(cpm, ContextParallelInput) and cpm.split_output:
|
||||
continue
|
||||
|
||||
# Maybe the parameter was passed as a keyword argument
|
||||
input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
|
||||
name, args_list, kwargs
|
||||
)
|
||||
|
||||
if input_val is None:
|
||||
continue
|
||||
|
||||
# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
|
||||
# the output instead of input for a particular layer by setting split_output=True
|
||||
if isinstance(input_val, torch.Tensor):
|
||||
input_val = self._prepare_cp_input(input_val, cpm)
|
||||
elif isinstance(input_val, (list, tuple)):
|
||||
if len(input_val) != len(cpm):
|
||||
raise ValueError(
|
||||
f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
|
||||
)
|
||||
sharded_input_val = []
|
||||
for i, x in enumerate(input_val):
|
||||
if torch.is_tensor(x) and not cpm[i].split_output:
|
||||
x = self._prepare_cp_input(x, cpm[i])
|
||||
sharded_input_val.append(x)
|
||||
input_val = sharded_input_val
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(input_val)}")
|
||||
|
||||
if is_kwarg:
|
||||
kwargs[name] = input_val
|
||||
elif index is not None and index < len(args_list):
|
||||
args_list[index] = input_val
|
||||
else:
|
||||
raise ValueError(
|
||||
f"An unexpected error occurred while processing the input '{name}'. Please open an "
|
||||
f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
|
||||
f"example along with the full stack trace."
|
||||
)
|
||||
|
||||
return tuple(args_list), kwargs
|
||||
|
||||
def post_forward(self, module, output):
|
||||
is_tensor = isinstance(output, torch.Tensor)
|
||||
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
|
||||
|
||||
if not is_tensor and not is_tensor_list:
|
||||
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
|
||||
|
||||
output = [output] if is_tensor else list(output)
|
||||
for index, cpm in self.metadata.items():
|
||||
if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
|
||||
continue
|
||||
if index >= len(output):
|
||||
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
|
||||
current_output = output[index]
|
||||
current_output = self._prepare_cp_input(current_output, cpm)
|
||||
output[index] = current_output
|
||||
|
||||
return output[0] if is_tensor else tuple(output)
|
||||
|
||||
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
|
||||
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
|
||||
raise ValueError(
|
||||
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
|
||||
)
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
|
||||
class ContextParallelGatherHook(ModelHook):
|
||||
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
|
||||
super().__init__()
|
||||
self.metadata = metadata
|
||||
self.parallel_config = parallel_config
|
||||
|
||||
def post_forward(self, module, output):
|
||||
is_tensor = isinstance(output, torch.Tensor)
|
||||
|
||||
if is_tensor:
|
||||
output = [output]
|
||||
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
|
||||
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
|
||||
|
||||
output = list(output)
|
||||
|
||||
if len(output) != len(self.metadata):
|
||||
raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
|
||||
|
||||
for i, cpm in enumerate(self.metadata):
|
||||
if cpm is None:
|
||||
continue
|
||||
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
return output[0] if is_tensor else tuple(output)
|
||||
|
||||
|
||||
class AllGatherFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, tensor, dim, group):
|
||||
ctx.dim = dim
|
||||
ctx.group = group
|
||||
ctx.world_size = torch.distributed.get_world_size(group)
|
||||
ctx.rank = torch.distributed.get_rank(group)
|
||||
return funcol.all_gather_tensor(tensor, dim, group=group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
|
||||
return grad_chunks[ctx.rank], None, None
|
||||
|
||||
|
||||
class EquipartitionSharder:
|
||||
@classmethod
|
||||
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
|
||||
# NOTE: the following assertion does not have to be true in general. We simply enforce it for now
|
||||
# because the alternate case has not yet been tested/required for any model.
|
||||
assert tensor.size()[dim] % mesh.size() == 0, (
|
||||
"Tensor size along dimension to be sharded must be divisible by mesh size"
|
||||
)
|
||||
|
||||
# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
|
||||
# return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
|
||||
|
||||
return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
|
||||
|
||||
@classmethod
|
||||
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
|
||||
tensor = tensor.contiguous()
|
||||
tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
|
||||
return tensor
|
||||
|
||||
|
||||
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
if name.count("*") > 1:
|
||||
raise ValueError("Wildcard '*' can only be used once in the name")
|
||||
return _find_submodule_by_name(model, name)
|
||||
|
||||
|
||||
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
if name == "":
|
||||
return model
|
||||
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
|
||||
if first_atom == "*":
|
||||
if not isinstance(model, torch.nn.ModuleList):
|
||||
raise ValueError("Wildcard '*' can only be used with ModuleList")
|
||||
submodules = []
|
||||
for submodule in model:
|
||||
subsubmodules = _find_submodule_by_name(submodule, remaining_name)
|
||||
if not isinstance(subsubmodules, list):
|
||||
subsubmodules = [subsubmodules]
|
||||
submodules.extend(subsubmodules)
|
||||
return submodules
|
||||
else:
|
||||
if hasattr(model, first_atom):
|
||||
submodule = getattr(model, first_atom)
|
||||
return _find_submodule_by_name(submodule, remaining_name)
|
||||
else:
|
||||
raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
|
||||
@@ -523,6 +523,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
size=(height, width),
|
||||
)
|
||||
image = self.pt_to_numpy(image)
|
||||
|
||||
return image
|
||||
|
||||
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
@@ -838,6 +839,137 @@ class VaeImageProcessor(ConfigMixin):
|
||||
return image
|
||||
|
||||
|
||||
class InpaintProcessor(ConfigMixin):
|
||||
"""
|
||||
Image processor for inpainting image and mask.
|
||||
"""
|
||||
|
||||
config_name = CONFIG_NAME
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
vae_scale_factor: int = 8,
|
||||
vae_latent_channels: int = 4,
|
||||
resample: str = "lanczos",
|
||||
reducing_gap: int = None,
|
||||
do_normalize: bool = True,
|
||||
do_binarize: bool = False,
|
||||
do_convert_grayscale: bool = False,
|
||||
mask_do_normalize: bool = False,
|
||||
mask_do_binarize: bool = True,
|
||||
mask_do_convert_grayscale: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._image_processor = VaeImageProcessor(
|
||||
do_resize=do_resize,
|
||||
vae_scale_factor=vae_scale_factor,
|
||||
vae_latent_channels=vae_latent_channels,
|
||||
resample=resample,
|
||||
reducing_gap=reducing_gap,
|
||||
do_normalize=do_normalize,
|
||||
do_binarize=do_binarize,
|
||||
do_convert_grayscale=do_convert_grayscale,
|
||||
)
|
||||
self._mask_processor = VaeImageProcessor(
|
||||
do_resize=do_resize,
|
||||
vae_scale_factor=vae_scale_factor,
|
||||
vae_latent_channels=vae_latent_channels,
|
||||
resample=resample,
|
||||
reducing_gap=reducing_gap,
|
||||
do_normalize=mask_do_normalize,
|
||||
do_binarize=mask_do_binarize,
|
||||
do_convert_grayscale=mask_do_convert_grayscale,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
image: PIL.Image.Image,
|
||||
mask: PIL.Image.Image = None,
|
||||
height: int = None,
|
||||
width: int = None,
|
||||
padding_mask_crop: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preprocess the image and mask.
|
||||
"""
|
||||
if mask is None and padding_mask_crop is not None:
|
||||
raise ValueError("mask must be provided if padding_mask_crop is provided")
|
||||
|
||||
# if mask is None, same behavior as regular image processor
|
||||
if mask is None:
|
||||
return self._image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
|
||||
resize_mode = "fill"
|
||||
else:
|
||||
crops_coords = None
|
||||
resize_mode = "default"
|
||||
|
||||
processed_image = self._image_processor.preprocess(
|
||||
image,
|
||||
height=height,
|
||||
width=width,
|
||||
crops_coords=crops_coords,
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
processed_mask = self._mask_processor.preprocess(
|
||||
mask,
|
||||
height=height,
|
||||
width=width,
|
||||
resize_mode=resize_mode,
|
||||
crops_coords=crops_coords,
|
||||
)
|
||||
|
||||
if crops_coords is not None:
|
||||
postprocessing_kwargs = {
|
||||
"crops_coords": crops_coords,
|
||||
"original_image": image,
|
||||
"original_mask": mask,
|
||||
}
|
||||
else:
|
||||
postprocessing_kwargs = {
|
||||
"crops_coords": None,
|
||||
"original_image": None,
|
||||
"original_mask": None,
|
||||
}
|
||||
|
||||
return processed_image, processed_mask, postprocessing_kwargs
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
output_type: str = "pil",
|
||||
original_image: Optional[PIL.Image.Image] = None,
|
||||
original_mask: Optional[PIL.Image.Image] = None,
|
||||
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
|
||||
"""
|
||||
Postprocess the image, optionally apply mask overlay
|
||||
"""
|
||||
image = self._image_processor.postprocess(
|
||||
image,
|
||||
output_type=output_type,
|
||||
)
|
||||
# optionally apply the mask overlay
|
||||
if crops_coords is not None and (original_image is None or original_mask is None):
|
||||
raise ValueError("original_image and original_mask must be provided if crops_coords is provided")
|
||||
|
||||
elif crops_coords is not None and output_type != "pil":
|
||||
raise ValueError("output_type must be 'pil' if crops_coords is provided")
|
||||
|
||||
elif crops_coords is not None:
|
||||
image = [
|
||||
self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
|
||||
]
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
"""
|
||||
Image processor for VAE LDM3D.
|
||||
|
||||
@@ -1064,6 +1064,41 @@ class LoraBaseMixin:
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
@classmethod
|
||||
def _save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
|
||||
lora_metadata: Dict[str, Optional[dict]],
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
"""
|
||||
Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
|
||||
pipeline types.
|
||||
"""
|
||||
state_dict = {}
|
||||
final_lora_adapter_metadata = {}
|
||||
|
||||
for prefix, layers in lora_layers.items():
|
||||
state_dict.update(cls.pack_weights(layers, prefix))
|
||||
|
||||
for prefix, metadata in lora_metadata.items():
|
||||
if metadata:
|
||||
final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
|
||||
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@@ -558,70 +558,62 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
ait_sd[target_key] = value
|
||||
|
||||
if any("guidance_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_guidance_in_in_layer",
|
||||
"time_text_embed.guidance_embedder.linear_1",
|
||||
)
|
||||
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_guidance_in_out_layer",
|
||||
"time_text_embed.guidance_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("img_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_img_in",
|
||||
"x_embedder",
|
||||
)
|
||||
|
||||
if any("txt_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_txt_in",
|
||||
"context_embedder",
|
||||
)
|
||||
|
||||
if any("time_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_time_in_in_layer",
|
||||
"time_text_embed.timestep_embedder.linear_1",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_time_in_out_layer",
|
||||
"time_text_embed.timestep_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("vector_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_vector_in_in_layer",
|
||||
"time_text_embed.text_embedder.linear_1",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_vector_in_out_layer",
|
||||
"time_text_embed.text_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("final_layer" in k for k in sds_sd):
|
||||
|
||||
+256
-2362
File diff suppressed because it is too large
Load Diff
@@ -22,6 +22,7 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import __version__
|
||||
from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
@@ -297,6 +298,7 @@ class FromOriginalModelMixin:
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
|
||||
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
|
||||
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
||||
@@ -403,19 +405,8 @@ class FromOriginalModelMixin:
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
model_state_dict = model.state_dict()
|
||||
|
||||
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
diffusers_format_checkpoint = checkpoint
|
||||
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
@@ -428,6 +419,26 @@ class FromOriginalModelMixin:
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
# Now that the model is loaded, we can determine the `device_map`
|
||||
device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer)
|
||||
if device_map is not None:
|
||||
expanded_device_map = _expand_device_map(device_map, model_state_dict.keys())
|
||||
_caching_allocator_warmup(model, expanded_device_map, torch_dtype, hf_quantizer)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
|
||||
if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
diffusers_format_checkpoint = checkpoint
|
||||
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model,
|
||||
|
||||
@@ -25,6 +25,7 @@ from ..utils import (
|
||||
_import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
|
||||
_import_structure["auto_model"] = ["AutoModel"]
|
||||
@@ -119,6 +120,7 @@ if is_flax_available():
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from ._modeling_parallel import ContextParallelConfig, ParallelConfig
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .attention_dispatch import AttentionBackendName, attention_backend
|
||||
from .auto_model import AutoModel
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨
|
||||
# Experimental changes are subject to change and APIs may break without warning.
|
||||
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# TODO(aryan): add support for the following:
|
||||
# - Unified Attention
|
||||
# - More dispatcher attention backends
|
||||
# - CFG/Data Parallel
|
||||
# - Tensor Parallel
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextParallelConfig:
|
||||
"""
|
||||
Configuration for context parallelism.
|
||||
|
||||
Args:
|
||||
ring_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
ulysses_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert output and LSE to float32 for ring attention numerical stability.
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
|
||||
is supported.
|
||||
|
||||
"""
|
||||
|
||||
ring_degree: Optional[int] = None
|
||||
ulysses_degree: Optional[int] = None
|
||||
convert_to_fp32: bool = True
|
||||
# TODO: support alltoall
|
||||
rotate_method: Literal["allgather", "alltoall"] = "allgather"
|
||||
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
_device: torch.device = None
|
||||
_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_ring_local_rank: int = None
|
||||
_ulysses_local_rank: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.ring_degree is None:
|
||||
self.ring_degree = 1
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
|
||||
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._mesh = mesh
|
||||
if self.ring_degree is None:
|
||||
self.ring_degree = 1
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
)
|
||||
if self._flattened_mesh is None:
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
if self._ring_mesh is None:
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
if self._ulysses_mesh is None:
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
if self._ring_local_rank is None:
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
if self._ulysses_local_rank is None:
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelConfig:
|
||||
"""
|
||||
Configuration for applying different parallelisms.
|
||||
|
||||
Args:
|
||||
context_parallel_config (`ContextParallelConfig`, *optional*):
|
||||
Configuration for context parallelism.
|
||||
"""
|
||||
|
||||
context_parallel_config: Optional[ContextParallelConfig] = None
|
||||
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
_device: torch.device = None
|
||||
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
|
||||
def setup(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
*,
|
||||
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._cp_mesh = cp_mesh
|
||||
if self.context_parallel_config is not None:
|
||||
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextParallelInput:
|
||||
"""
|
||||
Configuration for splitting an input tensor across context parallel region.
|
||||
|
||||
Args:
|
||||
split_dim (`int`):
|
||||
The dimension along which to split the tensor.
|
||||
expected_dims (`int`, *optional*):
|
||||
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
|
||||
tensor has the expected number of dimensions before splitting.
|
||||
split_output (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
|
||||
This is useful for layers whose outputs should be split after it does some preprocessing on the inputs (ex:
|
||||
RoPE).
|
||||
"""
|
||||
|
||||
split_dim: int
|
||||
expected_dims: Optional[int] = None
|
||||
split_output: bool = False
|
||||
|
||||
def __repr__(self):
|
||||
return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextParallelOutput:
|
||||
"""
|
||||
Configuration for gathering an output tensor across context parallel region.
|
||||
|
||||
Args:
|
||||
gather_dim (`int`):
|
||||
The dimension along which to gather the tensor.
|
||||
expected_dims (`int`, *optional*):
|
||||
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
|
||||
tensor has the expected number of dimensions before gathering.
|
||||
"""
|
||||
|
||||
gather_dim: int
|
||||
expected_dims: Optional[int] = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"
|
||||
|
||||
|
||||
# A dictionary where keys denote the input to be split across context parallel region, and the
|
||||
# value denotes the sharding configuration.
|
||||
# If the key is a string, it denotes the name of the parameter in the forward function.
|
||||
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
|
||||
# to be split across context parallel region.
|
||||
ContextParallelInputType = Dict[
|
||||
Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
|
||||
]
|
||||
|
||||
# A dictionary where keys denote the output to be gathered across context parallel region, and the
|
||||
# value denotes the gathering configuration.
|
||||
ContextParallelOutputType = Union[
|
||||
ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
|
||||
]
|
||||
|
||||
# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
|
||||
# the module should be split/gathered across context parallel region.
|
||||
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
|
||||
|
||||
|
||||
# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
|
||||
#
|
||||
# Each model should define a _cp_plan attribute that contains information on how to shard/gather
|
||||
# tensors at different stages of the forward:
|
||||
#
|
||||
# ```python
|
||||
# _cp_plan = {
|
||||
# "": {
|
||||
# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
# },
|
||||
# "pos_embed": {
|
||||
# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
# },
|
||||
# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
# }
|
||||
# ```
|
||||
#
|
||||
# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
|
||||
# split/gathered according to this at the respective module level. Here, the following happens:
|
||||
# - "":
|
||||
# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
|
||||
# the actual forward logic of the QwenImageTransformer2DModel is run, we will splitthe inputs)
|
||||
# - "pos_embed":
|
||||
# we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (imag & text freqs),
|
||||
# we can individually specify how they should be split
|
||||
# - "proj_out":
|
||||
# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
|
||||
# layer forward has run).
|
||||
#
|
||||
# ContextParallelInput:
|
||||
# specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
|
||||
#
|
||||
# ContextParallelOutput:
|
||||
# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
|
||||
@@ -241,7 +241,7 @@ class AttentionModuleMixin:
|
||||
op_fw, op_bw = attention_op
|
||||
dtype, *_ = op_fw.SUPPORTED_DTYPES
|
||||
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
|
||||
_ = xops.memory_efficient_attention(q, q, q)
|
||||
_ = xops.ops.memory_efficient_attention(q, q, q)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -674,7 +674,7 @@ class JointTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import logging
|
||||
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -114,6 +115,8 @@ class AutoModel(ConfigMixin):
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
trust_remote_cocde (`bool`, *optional*, defaults to `False`):
|
||||
Whether to trust remote code
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -140,22 +143,22 @@ class AutoModel(ConfigMixin):
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
```
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
|
||||
load_config_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"token": token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
}
|
||||
hub_kwargs_names = [
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
|
||||
|
||||
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
|
||||
|
||||
library = None
|
||||
orig_class_name = None
|
||||
@@ -189,15 +192,35 @@ class AutoModel(ConfigMixin):
|
||||
else:
|
||||
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
|
||||
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||
trust_remote_code = resolve_trust_remote_code(trust_remote_code, pretrained_model_or_path, has_remote_code)
|
||||
if not has_remote_code and trust_remote_code:
|
||||
raise ValueError(
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
|
||||
model_cls, _ = get_class_obj_and_candidates(
|
||||
library_name=library,
|
||||
class_name=orig_class_name,
|
||||
importable_classes=ALL_IMPORTABLE_CLASSES,
|
||||
pipelines=None,
|
||||
is_pipeline_module=False,
|
||||
)
|
||||
if has_remote_code and trust_remote_code:
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
module_file = module_file + ".py"
|
||||
model_cls = get_class_from_dynamic_module(
|
||||
pretrained_model_or_path,
|
||||
subfolder=subfolder,
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
model_cls, _ = get_class_obj_and_candidates(
|
||||
library_name=library,
|
||||
class_name=orig_class_name,
|
||||
importable_classes=ALL_IMPORTABLE_CLASSES,
|
||||
pipelines=None,
|
||||
is_pipeline_module=False,
|
||||
)
|
||||
|
||||
if model_cls is None:
|
||||
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
|
||||
|
||||
@@ -617,7 +617,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.size(0) > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
@@ -1052,7 +1052,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
is_residual=is_residual,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
||||
self.spatial_compression_ratio = scale_factor_spatial
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
@@ -1145,12 +1145,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def _encode(self, x: torch.Tensor):
|
||||
_, _, num_frame, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
self.clear_cache()
|
||||
if self.config.patch_size is not None:
|
||||
x = patchify(x, patch_size=self.config.patch_size)
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
iter_ = 1 + (num_frame - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
|
||||
@@ -26,11 +26,11 @@ from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from huggingface_hub import create_repo, hf_hub_download
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
from requests import HTTPError
|
||||
|
||||
from .. import __version__, is_torch_available
|
||||
from ..utils import (
|
||||
@@ -385,7 +385,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
except HfHubHTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
||||
f"{err}"
|
||||
|
||||
@@ -65,6 +65,7 @@ from ..utils.hub_utils import (
|
||||
populate_model_card,
|
||||
)
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
|
||||
from .model_loading_utils import (
|
||||
_caching_allocator_warmup,
|
||||
_determine_device_map,
|
||||
@@ -248,6 +249,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
_skip_layerwise_casting_patterns = None
|
||||
_supports_group_offloading = True
|
||||
_repeated_blocks = []
|
||||
_parallel_config = None
|
||||
_cp_plan = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -620,8 +623,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
def reset_attention_backend(self) -> None:
|
||||
"""
|
||||
Resets the attention backend for the model. Following calls to `forward` will use the environment default or
|
||||
the torch native scaled dot product attention.
|
||||
Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
|
||||
set, or the torch native scaled dot product attention.
|
||||
"""
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
@@ -960,6 +963,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
|
||||
|
||||
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
|
||||
if is_parallel_loading_enabled and not low_cpu_mem_usage:
|
||||
@@ -1340,6 +1344,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
if parallel_config is not None:
|
||||
model.enable_parallelism(config=parallel_config)
|
||||
|
||||
if output_loading_info:
|
||||
return model, loading_info
|
||||
|
||||
@@ -1478,6 +1485,73 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
|
||||
)
|
||||
|
||||
def enable_parallelism(
|
||||
self,
|
||||
*,
|
||||
config: Union[ParallelConfig, ContextParallelConfig],
|
||||
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
|
||||
):
|
||||
from ..hooks.context_parallel import apply_context_parallel
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
|
||||
logger.warning(
|
||||
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
|
||||
)
|
||||
|
||||
if isinstance(config, ContextParallelConfig):
|
||||
config = ParallelConfig(context_parallel_config=config)
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
device_type = torch._C._get_accelerator().type
|
||||
device_module = torch.get_device_module(device_type)
|
||||
device = torch.device(device_type, rank % device_module.device_count())
|
||||
|
||||
cp_mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
|
||||
raise ValueError(
|
||||
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
|
||||
mesh_dim_names=("ring", "ulysses"),
|
||||
)
|
||||
|
||||
config.setup(rank, world_size, device, cp_mesh=cp_mesh)
|
||||
|
||||
if cp_plan is None and self._cp_plan is None:
|
||||
raise ValueError(
|
||||
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
|
||||
)
|
||||
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
apply_context_parallel(self, config.context_parallel_config, cp_plan)
|
||||
|
||||
self._parallel_config = config
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
processor = module.processor
|
||||
if processor is None or not hasattr(processor, "_parallel_config"):
|
||||
continue
|
||||
processor._parallel_config = config
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -92,7 +92,7 @@ class AuraFlowPatchEmbed(nn.Module):
|
||||
|
||||
return selected_indices
|
||||
|
||||
def forward(self, latent):
|
||||
def forward(self, latent) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = latent.size()
|
||||
latent = latent.view(
|
||||
batch_size,
|
||||
@@ -173,7 +173,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
@@ -242,7 +242,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = hidden_states
|
||||
residual_context = encoder_hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
@@ -472,7 +472,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
timestep: torch.LongTensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -122,7 +122,7 @@ class CogVideoXBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
@@ -441,7 +441,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -315,7 +315,7 @@ class ConsisIDBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
@@ -691,7 +691,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
id_cond: Optional[torch.Tensor] = None,
|
||||
id_vit_hidden: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -124,7 +124,7 @@ class LuminaNextDiTBlock(nn.Module):
|
||||
encoder_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a forward pass through the LuminaNextDiTBlock.
|
||||
|
||||
@@ -297,7 +297,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
|
||||
image_rotary_emb: torch.Tensor,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
return_dict=True,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
"""
|
||||
Forward pass of LuminaNextDiT.
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ def get_1d_rotary_pos_embed(
|
||||
|
||||
class BriaAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -161,7 +162,12 @@ class BriaAttnProcessor:
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -472,7 +478,7 @@ class BriaSingleTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
@@ -588,7 +594,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
return_dict: bool = True,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`BriaTransformer2DModel`] forward method.
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -79,7 +79,7 @@ class CogView3PlusTransformerBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
@@ -293,7 +293,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`CogView3PlusTransformer2DModel`] forward method.
|
||||
|
||||
|
||||
@@ -494,7 +494,7 @@ class CogView4TransformerBlock(nn.Module):
|
||||
] = None,
|
||||
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Timestep conditioning
|
||||
(
|
||||
norm_hidden_states,
|
||||
@@ -717,7 +717,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -73,6 +74,7 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st
|
||||
|
||||
class FluxAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -114,7 +116,12 @@ class FluxAttnProcessor:
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -136,6 +143,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
"""Flux Attention processor for IP-Adapter."""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
|
||||
@@ -220,6 +228,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -252,6 +261,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
|
||||
@@ -556,6 +566,15 @@ class FluxTransformer2DModel(
|
||||
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -55,7 +55,7 @@ class HiDreamImageTimestepEmbed(nn.Module):
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
|
||||
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
||||
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
|
||||
t_emb = self.timestep_embedder(t_emb)
|
||||
return t_emb
|
||||
@@ -87,7 +87,7 @@ class HiDreamImagePatchEmbed(nn.Module):
|
||||
self.out_channels = out_channels
|
||||
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
|
||||
|
||||
def forward(self, latent):
|
||||
def forward(self, latent) -> torch.Tensor:
|
||||
latent = self.proj(latent)
|
||||
return latent
|
||||
|
||||
@@ -534,7 +534,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
wtype = hidden_states.dtype
|
||||
(
|
||||
shift_msa_i,
|
||||
@@ -592,7 +592,7 @@ class HiDreamBlock(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
return self.block(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_masks=hidden_states_masks,
|
||||
@@ -786,7 +786,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
|
||||
@@ -529,7 +529,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
@@ -684,7 +684,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
token_replace_emb: torch.Tensor = None,
|
||||
num_tokens: int = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
@@ -1038,7 +1038,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
guidance: torch.Tensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -216,7 +216,7 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
indices_latents_history_4x: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -51,6 +52,7 @@ class LTXVideoAttnProcessor:
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if is_torch_version("<", "2.0"):
|
||||
@@ -100,6 +102,7 @@ class LTXVideoAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -409,6 +412,18 @@ class LTXVideoTransformer3DModel(
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
_repeated_blocks = ["LTXVideoTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
"rope": {
|
||||
0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
|
||||
1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -25,6 +25,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
@@ -261,6 +262,7 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -334,6 +336,7 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
@@ -502,6 +505,18 @@ class QwenImageTransformer2DModel(
|
||||
_no_split_modules = ["QwenImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
"pos_embed": {
|
||||
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -73,6 +73,7 @@ def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states
|
||||
|
||||
class SkyReelsV2AttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -139,6 +140,7 @@ class SkyReelsV2AttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
@@ -151,6 +153,7 @@ class SkyReelsV2AttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
@@ -23,6 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -66,6 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
|
||||
|
||||
class WanAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -132,6 +134,7 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
@@ -144,6 +147,7 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
@@ -539,6 +543,19 @@ class WanTransformer3DModel(
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
_repeated_blocks = ["WanTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"rope": {
|
||||
0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
|
||||
1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
|
||||
},
|
||||
"blocks.0": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"blocks.*": {
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -82,6 +82,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
time_embedding_dim: Optional[int] = None,
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
@@ -100,15 +101,23 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
||||
if time_embed_dim % 2 != 0:
|
||||
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
timestep_input_dim = time_embed_dim
|
||||
elif time_embedding_type == "positional":
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
||||
)
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
@@ -47,6 +47,12 @@ else:
|
||||
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
|
||||
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
|
||||
_import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
|
||||
_import_structure["qwenimage"] = [
|
||||
"QwenImageAutoBlocks",
|
||||
"QwenImageModularPipeline",
|
||||
"QwenImageEditModularPipeline",
|
||||
"QwenImageEditAutoBlocks",
|
||||
]
|
||||
_import_structure["components_manager"] = ["ComponentsManager"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -68,6 +74,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SequentialPipelineBlocks,
|
||||
)
|
||||
from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
|
||||
from .qwenimage import (
|
||||
QwenImageAutoBlocks,
|
||||
QwenImageEditAutoBlocks,
|
||||
QwenImageEditModularPipeline,
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
from .wan import WanAutoBlocks, WanModularPipeline
|
||||
else:
|
||||
|
||||
@@ -25,6 +25,7 @@ from ..utils import (
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.torch_utils import get_device
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -161,7 +162,9 @@ class AutoOffloadStrategy:
|
||||
|
||||
current_module_size = model.get_memory_footprint()
|
||||
|
||||
mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
|
||||
device_type = execution_device.type
|
||||
device_module = getattr(torch, device_type, torch.cuda)
|
||||
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
|
||||
mem_on_device = mem_on_device - self.memory_reserve_margin
|
||||
if current_module_size < mem_on_device:
|
||||
return []
|
||||
@@ -301,7 +304,7 @@ class ComponentsManager:
|
||||
cm.add("vae", vae_model, collection="sdxl")
|
||||
|
||||
# Enable auto offloading
|
||||
cm.enable_auto_cpu_offload(device="cuda")
|
||||
cm.enable_auto_cpu_offload()
|
||||
|
||||
# Retrieve components
|
||||
unet = cm.get_one(name="unet", collection="sdxl")
|
||||
@@ -490,6 +493,8 @@ class ComponentsManager:
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
if torch.xpu.is_available():
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
# YiYi TODO: rename to search_components for now, may remove this method
|
||||
def search_components(
|
||||
@@ -678,7 +683,7 @@ class ComponentsManager:
|
||||
|
||||
return get_return_dict(matches, return_dict_with_names)
|
||||
|
||||
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
|
||||
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, memory_reserve_margin="3GB"):
|
||||
"""
|
||||
Enable automatic CPU offloading for all components.
|
||||
|
||||
@@ -704,6 +709,8 @@ class ComponentsManager:
|
||||
|
||||
self.disable_auto_cpu_offload()
|
||||
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
|
||||
if device is None:
|
||||
device = get_device()
|
||||
device = torch.device(device)
|
||||
if device.index is None:
|
||||
device = torch.device(f"{device.type}:{0}")
|
||||
|
||||
@@ -454,6 +454,9 @@ class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
|
||||
scheduler = components.scheduler
|
||||
transformer = components.transformer
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
@@ -659,8 +662,6 @@ class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
|
||||
@@ -148,8 +148,8 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("set_timesteps", FluxSetTimestepsStep),
|
||||
("prepare_latents", FluxPrepareLatentsStep),
|
||||
("set_timesteps", FluxSetTimestepsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
]
|
||||
|
||||
@@ -56,6 +56,8 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
|
||||
("wan", "WanModularPipeline"),
|
||||
("flux", "FluxModularPipeline"),
|
||||
("qwenimage", "QwenImageModularPipeline"),
|
||||
("qwenimage-edit", "QwenImageEditModularPipeline"),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -64,6 +66,8 @@ MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
|
||||
("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
|
||||
("WanModularPipeline", "WanAutoBlocks"),
|
||||
("FluxModularPipeline", "FluxAutoBlocks"),
|
||||
("QwenImageModularPipeline", "QwenImageAutoBlocks"),
|
||||
("QwenImageEditModularPipeline", "QwenImageEditAutoBlocks"),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -133,8 +137,8 @@ class PipelineState:
|
||||
Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
|
||||
intermediates dict.
|
||||
"""
|
||||
if name in self.intermediates:
|
||||
return self.intermediates[name]
|
||||
if name in self.values:
|
||||
return self.values[name]
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def __repr__(self):
|
||||
@@ -319,7 +323,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
)
|
||||
if not (has_remote_code and trust_remote_code):
|
||||
if not has_remote_code and trust_remote_code:
|
||||
raise ValueError(
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
@@ -548,8 +552,11 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
def __init__(self):
|
||||
sub_blocks = InsertableDict()
|
||||
for block_name, block_cls in zip(self.block_names, self.block_classes):
|
||||
sub_blocks[block_name] = block_cls()
|
||||
for block_name, block in zip(self.block_names, self.block_classes):
|
||||
if inspect.isclass(block):
|
||||
sub_blocks[block_name] = block()
|
||||
else:
|
||||
sub_blocks[block_name] = block
|
||||
self.sub_blocks = sub_blocks
|
||||
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
||||
raise ValueError(
|
||||
@@ -830,7 +837,9 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
return expected_configs
|
||||
|
||||
@classmethod
|
||||
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks":
|
||||
def from_blocks_dict(
|
||||
cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
|
||||
) -> "SequentialPipelineBlocks":
|
||||
"""Creates a SequentialPipelineBlocks instance from a dictionary of blocks.
|
||||
|
||||
Args:
|
||||
@@ -852,12 +861,19 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
instance.block_classes = [block.__class__ for block in sub_blocks.values()]
|
||||
instance.block_names = list(sub_blocks.keys())
|
||||
instance.sub_blocks = sub_blocks
|
||||
|
||||
if description is not None:
|
||||
instance.description = description
|
||||
|
||||
return instance
|
||||
|
||||
def __init__(self):
|
||||
sub_blocks = InsertableDict()
|
||||
for block_name, block_cls in zip(self.block_names, self.block_classes):
|
||||
sub_blocks[block_name] = block_cls()
|
||||
for block_name, block in zip(self.block_names, self.block_classes):
|
||||
if inspect.isclass(block):
|
||||
sub_blocks[block_name] = block()
|
||||
else:
|
||||
sub_blocks[block_name] = block
|
||||
self.sub_blocks = sub_blocks
|
||||
|
||||
def _get_inputs(self):
|
||||
@@ -1280,8 +1296,11 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
def __init__(self):
|
||||
sub_blocks = InsertableDict()
|
||||
for block_name, block_cls in zip(self.block_names, self.block_classes):
|
||||
sub_blocks[block_name] = block_cls()
|
||||
for block_name, block in zip(self.block_names, self.block_classes):
|
||||
if inspect.isclass(block):
|
||||
sub_blocks[block_name] = block()
|
||||
else:
|
||||
sub_blocks[block_name] = block
|
||||
self.sub_blocks = sub_blocks
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["encoders"] = ["QwenImageTextEncoderStep"]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
"AUTO_BLOCKS",
|
||||
"CONTROLNET_BLOCKS",
|
||||
"EDIT_AUTO_BLOCKS",
|
||||
"EDIT_BLOCKS",
|
||||
"EDIT_INPAINT_BLOCKS",
|
||||
"IMAGE2IMAGE_BLOCKS",
|
||||
"INPAINT_BLOCKS",
|
||||
"TEXT2IMAGE_BLOCKS",
|
||||
"QwenImageAutoBlocks",
|
||||
"QwenImageEditAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .encoders import (
|
||||
QwenImageTextEncoderStep,
|
||||
)
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
AUTO_BLOCKS,
|
||||
CONTROLNET_BLOCKS,
|
||||
EDIT_AUTO_BLOCKS,
|
||||
EDIT_BLOCKS,
|
||||
EDIT_INPAINT_BLOCKS,
|
||||
IMAGE2IMAGE_BLOCKS,
|
||||
INPAINT_BLOCKS,
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
QwenImageAutoBlocks,
|
||||
QwenImageEditAutoBlocks,
|
||||
)
|
||||
from .modular_pipeline import QwenImageEditModularPipeline, QwenImageModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,727 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# modified from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
||||
def get_timesteps(scheduler, num_inference_steps, strength):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||
if hasattr(scheduler, "set_begin_index"):
|
||||
scheduler.set_begin_index(t_start * scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
# Prepare Latents steps
|
||||
|
||||
|
||||
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare initial random noise for the generation process"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="generator"),
|
||||
InputParam(
|
||||
name="batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
),
|
||||
InputParam(
|
||||
name="dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs, can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
vae_scale_factor=components.vae_scale_factor,
|
||||
)
|
||||
|
||||
device = components._execution_device
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
|
||||
# we can update the height and width here since it's used to generate the initial
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, components.num_channels_latents, 1, latent_height, latent_width)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.latents = randn_tensor(
|
||||
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
|
||||
)
|
||||
block_state.latents = components.pachifier.pack_latents(block_state.latents)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial random noised, can be generated in prepare latent step.",
|
||||
),
|
||||
InputParam(
|
||||
name="image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
|
||||
),
|
||||
InputParam(
|
||||
name="timesteps",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="initial_noise",
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial random noised used for inpainting denoising.",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(image_latents, latents):
|
||||
if image_latents.shape[0] != latents.shape[0]:
|
||||
raise ValueError(
|
||||
f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
|
||||
)
|
||||
|
||||
if image_latents.ndim != 3:
|
||||
raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(
|
||||
image_latents=block_state.image_latents,
|
||||
latents=block_state.latents,
|
||||
)
|
||||
|
||||
# prepare latent timestep
|
||||
latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
|
||||
|
||||
# make copy of initial_noise
|
||||
block_state.initial_noise = block_state.latents
|
||||
|
||||
# scale noise
|
||||
block_state.latents = components.scheduler.scale_noise(
|
||||
block_state.image_latents, latent_timestep, block_state.latents
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
name="processed_mask_image",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The processed mask to use for the inpainting process.",
|
||||
),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="dtype", required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process."
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
|
||||
height_latents = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
|
||||
width_latents = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
|
||||
|
||||
block_state.mask = torch.nn.functional.interpolate(
|
||||
block_state.processed_mask_image,
|
||||
size=(height_latents, width_latents),
|
||||
)
|
||||
|
||||
block_state.mask = block_state.mask.unsqueeze(2)
|
||||
block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1)
|
||||
block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype)
|
||||
|
||||
block_state.mask = components.pachifier.pack_latents(block_state.mask)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
# Set Timesteps steps
|
||||
|
||||
|
||||
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="num_inference_steps", default=50),
|
||||
InputParam(name="sigmas"),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to use for the denoising process, used to calculate the image sequence length.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
sigmas = (
|
||||
np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
|
||||
if block_state.sigmas is None
|
||||
else block_state.sigmas
|
||||
)
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len=block_state.latents.shape[1],
|
||||
base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
|
||||
max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
|
||||
base_shift=components.scheduler.config.get("base_shift", 0.5),
|
||||
max_shift=components.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
scheduler=components.scheduler,
|
||||
num_inference_steps=block_state.num_inference_steps,
|
||||
device=device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
components.scheduler.set_begin_index(0)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="num_inference_steps", default=50),
|
||||
InputParam(name="sigmas"),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to use for the denoising process, used to calculate the image sequence length.",
|
||||
),
|
||||
InputParam(name="strength", default=0.9),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="timesteps",
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
sigmas = (
|
||||
np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
|
||||
if block_state.sigmas is None
|
||||
else block_state.sigmas
|
||||
)
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len=block_state.latents.shape[1],
|
||||
base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
|
||||
max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
|
||||
base_shift=components.scheduler.config.get("base_shift", 0.5),
|
||||
max_shift=components.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
scheduler=components.scheduler,
|
||||
num_inference_steps=block_state.num_inference_steps,
|
||||
device=device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
block_state.timesteps, block_state.num_inference_steps = get_timesteps(
|
||||
scheduler=components.scheduler,
|
||||
num_inference_steps=block_state.num_inference_steps,
|
||||
strength=block_state.strength,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
# other inputs for denoiser
|
||||
|
||||
## RoPE inputs for denoiser
|
||||
|
||||
|
||||
class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
InputParam(name="negative_prompt_embeds_mask"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the images latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.img_shapes = [
|
||||
[
|
||||
(
|
||||
1,
|
||||
block_state.height // components.vae_scale_factor // 2,
|
||||
block_state.width // components.vae_scale_factor // 2,
|
||||
)
|
||||
]
|
||||
* block_state.batch_size
|
||||
]
|
||||
block_state.txt_seq_lens = (
|
||||
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||
)
|
||||
block_state.negative_txt_seq_lens = (
|
||||
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
|
||||
if block_state.negative_prompt_embeds_mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be place after prepare_latents step"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(
|
||||
name="resized_image", required=True, type_hint=torch.Tensor, description="The resized image input"
|
||||
),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
InputParam(name="negative_prompt_embeds_mask"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the images latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# for edit, image size can be different from the target size (height/width)
|
||||
image = (
|
||||
block_state.resized_image[0] if isinstance(block_state.resized_image, list) else block_state.resized_image
|
||||
)
|
||||
image_width, image_height = image.size
|
||||
|
||||
block_state.img_shapes = [
|
||||
[
|
||||
(
|
||||
1,
|
||||
block_state.height // components.vae_scale_factor // 2,
|
||||
block_state.width // components.vae_scale_factor // 2,
|
||||
),
|
||||
(1, image_height // components.vae_scale_factor // 2, image_width // components.vae_scale_factor // 2),
|
||||
]
|
||||
] * block_state.batch_size
|
||||
|
||||
block_state.txt_seq_lens = (
|
||||
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||
)
|
||||
block_state.negative_txt_seq_lens = (
|
||||
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
|
||||
if block_state.negative_prompt_embeds_mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
## ControlNet inputs for denoiser
|
||||
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("controlnet", QwenImageControlNetModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("control_guidance_start", default=0.0),
|
||||
InputParam("control_guidance_end", default=1.0),
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("control_image_latents", required=True),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
controlnet = unwrap_module(components.controlnet)
|
||||
|
||||
# control_guidance_start/control_guidance_end (align format)
|
||||
if not isinstance(block_state.control_guidance_start, list) and isinstance(
|
||||
block_state.control_guidance_end, list
|
||||
):
|
||||
block_state.control_guidance_start = len(block_state.control_guidance_end) * [
|
||||
block_state.control_guidance_start
|
||||
]
|
||||
elif not isinstance(block_state.control_guidance_end, list) and isinstance(
|
||||
block_state.control_guidance_start, list
|
||||
):
|
||||
block_state.control_guidance_end = len(block_state.control_guidance_start) * [
|
||||
block_state.control_guidance_end
|
||||
]
|
||||
elif not isinstance(block_state.control_guidance_start, list) and not isinstance(
|
||||
block_state.control_guidance_end, list
|
||||
):
|
||||
mult = (
|
||||
len(block_state.control_image_latents) if isinstance(controlnet, QwenImageMultiControlNetModel) else 1
|
||||
)
|
||||
block_state.control_guidance_start, block_state.control_guidance_end = (
|
||||
mult * [block_state.control_guidance_start],
|
||||
mult * [block_state.control_guidance_end],
|
||||
)
|
||||
|
||||
# controlnet_conditioning_scale (align format)
|
||||
if isinstance(controlnet, QwenImageMultiControlNetModel) and isinstance(
|
||||
block_state.controlnet_conditioning_scale, float
|
||||
):
|
||||
block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * mult
|
||||
|
||||
# controlnet_keep
|
||||
block_state.controlnet_keep = []
|
||||
for i in range(len(block_state.timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
|
||||
for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
|
||||
]
|
||||
block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, QwenImageControlNetModel) else keeps)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
@@ -0,0 +1,203 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import InpaintProcessor, VaeImageProcessor
|
||||
from ...models import AutoencoderKLQwenImage
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that decodes the latents to images"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to decode, can be generated in the denoise step",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
|
||||
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
|
||||
)
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
|
||||
block_state.latents = components.pachifier.unpack_latents(
|
||||
block_state.latents, block_state.height, block_state.width
|
||||
)
|
||||
block_state.latents = block_state.latents.to(components.vae.dtype)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(components.vae.config.latents_mean)
|
||||
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
||||
.to(block_state.latents.device, block_state.latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
||||
1, components.vae.config.z_dim, 1, 1, 1
|
||||
).to(block_state.latents.device, block_state.latents.dtype)
|
||||
block_state.latents = block_state.latents / latents_std + latents_mean
|
||||
block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0][:, :, 0]
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "postprocess the generated image"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("images", required=True, description="the generated image from decoders step"),
|
||||
InputParam(
|
||||
name="output_type",
|
||||
default="pil",
|
||||
type_hint=str,
|
||||
description="The type of the output images, can be 'pil', 'np', 'pt'",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(output_type):
|
||||
if output_type not in ["pil", "np", "pt"]:
|
||||
raise ValueError(f"Invalid output_type: {output_type}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(block_state.output_type)
|
||||
|
||||
block_state.images = components.image_processor.postprocess(
|
||||
image=block_state.images,
|
||||
output_type=block_state.output_type,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "postprocess the generated image, optional apply the mask overally to the original image.."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_mask_processor",
|
||||
InpaintProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("images", required=True, description="the generated image from decoders step"),
|
||||
InputParam(
|
||||
name="output_type",
|
||||
default="pil",
|
||||
type_hint=str,
|
||||
description="The type of the output images, can be 'pil', 'np', 'pt'",
|
||||
),
|
||||
InputParam("mask_overlay_kwargs"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(output_type, mask_overlay_kwargs):
|
||||
if output_type not in ["pil", "np", "pt"]:
|
||||
raise ValueError(f"Invalid output_type: {output_type}")
|
||||
|
||||
if mask_overlay_kwargs and output_type != "pil":
|
||||
raise ValueError("only support output_type 'pil' for mask overlay")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(block_state.output_type, block_state.mask_overlay_kwargs)
|
||||
|
||||
if block_state.mask_overlay_kwargs is None:
|
||||
mask_overlay_kwargs = {}
|
||||
else:
|
||||
mask_overlay_kwargs = block_state.mask_overlay_kwargs
|
||||
|
||||
block_state.images = components.image_mask_processor.postprocess(
|
||||
image=block_state.images,
|
||||
**mask_overlay_kwargs,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
@@ -0,0 +1,668 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import QwenImageControlNetModel, QwenImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
# one timestep
|
||||
block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
|
||||
block_state.latent_model_input = block_state.latents
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
# one timestep
|
||||
|
||||
block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_latents], dim=1)
|
||||
block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("controlnet", QwenImageControlNetModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that runs the controlnet before the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"control_image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
|
||||
),
|
||||
InputParam(
|
||||
"controlnet_conditioning_scale",
|
||||
type_hint=float,
|
||||
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
|
||||
),
|
||||
InputParam(
|
||||
"controlnet_keep",
|
||||
required=True,
|
||||
type_hint=List[float],
|
||||
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description=(
|
||||
"All conditional model inputs for the denoiser. "
|
||||
"It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: int):
|
||||
# cond_scale for the timestep (controlnet input)
|
||||
if isinstance(block_state.controlnet_keep[i], list):
|
||||
block_state.cond_scale = [
|
||||
c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])
|
||||
]
|
||||
else:
|
||||
controlnet_cond_scale = block_state.controlnet_conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
|
||||
|
||||
# run controlnet for the guidance batch
|
||||
controlnet_block_samples = components.controlnet(
|
||||
hidden_states=block_state.latent_model_input,
|
||||
controlnet_cond=block_state.control_image_latents,
|
||||
conditioning_scale=block_state.cond_scale,
|
||||
timestep=block_state.timestep / 1000,
|
||||
img_shapes=block_state.img_shapes,
|
||||
encoder_hidden_states=block_state.prompt_embeds,
|
||||
encoder_hidden_states_mask=block_state.prompt_embeds_mask,
|
||||
txt_seq_lens=block_state.txt_seq_lens,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
block_state.additional_cond_kwargs["controlnet_block_samples"] = controlnet_block_samples
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that denoise the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("transformer", QwenImageTransformer2DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
|
||||
),
|
||||
InputParam(
|
||||
"img_shapes",
|
||||
required=True,
|
||||
type_hint=List[Tuple[int, int]],
|
||||
description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
guider_input_fields = {
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
|
||||
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
|
||||
# YiYi TODO: add cache context
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
hidden_states=block_state.latent_model_input,
|
||||
timestep=block_state.timestep / 1000,
|
||||
img_shapes=block_state.img_shapes,
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
**block_state.additional_cond_kwargs,
|
||||
)[0]
|
||||
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
guider_output = components.guider(guider_state)
|
||||
|
||||
# apply guidance rescale
|
||||
pred_cond_norm = torch.norm(guider_output.pred_cond, dim=-1, keepdim=True)
|
||||
pred_norm = torch.norm(guider_output.pred, dim=-1, keepdim=True)
|
||||
block_state.noise_pred = guider_output.pred * (pred_cond_norm / pred_norm)
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that denoise the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("transformer", QwenImageTransformer2DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
|
||||
),
|
||||
InputParam(
|
||||
"img_shapes",
|
||||
required=True,
|
||||
type_hint=List[Tuple[int, int]],
|
||||
description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
guider_input_fields = {
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
|
||||
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
|
||||
# YiYi TODO: add cache context
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
hidden_states=block_state.latent_model_input,
|
||||
timestep=block_state.timestep / 1000,
|
||||
img_shapes=block_state.img_shapes,
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
**block_state.additional_cond_kwargs,
|
||||
)[0]
|
||||
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
guider_output = components.guider(guider_state)
|
||||
|
||||
pred = guider_output.pred[:, : block_state.latents.size(1)]
|
||||
pred_cond = guider_output.pred_cond[:, : block_state.latents.size(1)]
|
||||
|
||||
# apply guidance rescale
|
||||
pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True)
|
||||
pred_norm = torch.norm(pred, dim=-1, keepdim=True)
|
||||
block_state.noise_pred = pred * (pred_cond_norm / pred_norm)
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that updates the latents. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
latents_dtype = block_state.latents.dtype
|
||||
block_state.latents = components.scheduler.step(
|
||||
block_state.noise_pred,
|
||||
t,
|
||||
block_state.latents,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if block_state.latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
block_state.latents = block_state.latents.to(latents_dtype)
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that updates the latents using mask and image_latents for inpainting. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"mask",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"initial_noise",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.init_latents_proper = block_state.image_latents
|
||||
if i < len(block_state.timesteps) - 1:
|
||||
block_state.noise_timestep = block_state.timesteps[i + 1]
|
||||
block_state.init_latents_proper = components.scheduler.scale_noise(
|
||||
block_state.init_latents_proper, torch.tensor([block_state.noise_timestep]), block_state.initial_noise
|
||||
)
|
||||
|
||||
block_state.latents = (
|
||||
1 - block_state.mask
|
||||
) * block_state.init_latents_proper + block_state.mask * block_state.latents
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Pipeline block that iteratively denoise the latents over `timesteps`. "
|
||||
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
|
||||
)
|
||||
|
||||
@property
|
||||
def loop_expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.num_warmup_steps = max(
|
||||
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
|
||||
)
|
||||
|
||||
block_state.additional_cond_kwargs = {}
|
||||
|
||||
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(block_state.timesteps):
|
||||
components, block_state = self.loop_step(components, block_state, i=i, t=t)
|
||||
if i == len(block_state.timesteps) - 1 or (
|
||||
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
# composing the denoising loops
|
||||
class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopDenoiser,
|
||||
QwenImageLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `QwenImageLoopBeforeDenoiser`\n"
|
||||
" - `QwenImageLoopDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiser`\n"
|
||||
"This block supports text2image and image2image tasks for QwenImage."
|
||||
)
|
||||
|
||||
|
||||
# composing the inpainting denoising loops
|
||||
class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopDenoiser,
|
||||
QwenImageLoopAfterDenoiser,
|
||||
QwenImageLoopAfterDenoiserInpaint,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `QwenImageLoopBeforeDenoiser`\n"
|
||||
" - `QwenImageLoopDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiserInpaint`\n"
|
||||
"This block supports inpainting tasks for QwenImage."
|
||||
)
|
||||
|
||||
|
||||
# composing the controlnet denoising loops
|
||||
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopBeforeDenoiserControlNet,
|
||||
QwenImageLoopDenoiser,
|
||||
QwenImageLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "before_denoiser_controlnet", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `QwenImageLoopBeforeDenoiser`\n"
|
||||
" - `QwenImageLoopBeforeDenoiserControlNet`\n"
|
||||
" - `QwenImageLoopDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiser`\n"
|
||||
"This block supports text2img/img2img tasks with controlnet for QwenImage."
|
||||
)
|
||||
|
||||
|
||||
# composing the controlnet denoising loops
|
||||
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopBeforeDenoiserControlNet,
|
||||
QwenImageLoopDenoiser,
|
||||
QwenImageLoopAfterDenoiser,
|
||||
QwenImageLoopAfterDenoiserInpaint,
|
||||
]
|
||||
block_names = [
|
||||
"before_denoiser",
|
||||
"before_denoiser_controlnet",
|
||||
"denoiser",
|
||||
"after_denoiser",
|
||||
"after_denoiser_inpaint",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `QwenImageLoopBeforeDenoiser`\n"
|
||||
" - `QwenImageLoopBeforeDenoiserControlNet`\n"
|
||||
" - `QwenImageLoopDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiserInpaint`\n"
|
||||
"This block supports inpainting tasks with controlnet for QwenImage."
|
||||
)
|
||||
|
||||
|
||||
# composing the denoising loops
|
||||
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
QwenImageEditLoopBeforeDenoiser,
|
||||
QwenImageEditLoopDenoiser,
|
||||
QwenImageLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `QwenImageEditLoopBeforeDenoiser`\n"
|
||||
" - `QwenImageEditLoopDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiser`\n"
|
||||
"This block supports QwenImage Edit."
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
QwenImageEditLoopBeforeDenoiser,
|
||||
QwenImageEditLoopDenoiser,
|
||||
QwenImageLoopAfterDenoiser,
|
||||
QwenImageLoopAfterDenoiserInpaint,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `QwenImageEditLoopBeforeDenoiser`\n"
|
||||
" - `QwenImageEditLoopDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiserInpaint`\n"
|
||||
"This block supports inpainting tasks for QwenImage Edit."
|
||||
)
|
||||
@@ -0,0 +1,857 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...image_processor import InpaintProcessor, VaeImageProcessor, is_valid_image, is_valid_image_imagelist
|
||||
from ...models import AutoencoderKLQwenImage, QwenImageControlNetModel, QwenImageMultiControlNetModel
|
||||
from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import unwrap_module
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
|
||||
def get_qwen_prompt_embeds(
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_template_encode: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
|
||||
prompt_template_encode_start_idx: int = 34,
|
||||
tokenizer_max_length: int = 1024,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
template = prompt_template_encode
|
||||
drop_idx = prompt_template_encode_start_idx
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = tokenizer(
|
||||
txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(device)
|
||||
encoder_hidden_states = text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = encoder_hidden_states.hidden_states[-1]
|
||||
|
||||
split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
||||
)
|
||||
encoder_attention_mask = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
|
||||
def get_qwen_prompt_embeds_edit(
|
||||
text_encoder,
|
||||
processor,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
|
||||
prompt_template_encode_start_idx: int = 64,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
template = prompt_template_encode
|
||||
drop_idx = prompt_template_encode_start_idx
|
||||
txt = [template.format(e) for e in prompt]
|
||||
|
||||
model_inputs = processor(
|
||||
text=txt,
|
||||
images=image,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
|
||||
outputs = text_encoder(
|
||||
input_ids=model_inputs.input_ids,
|
||||
attention_mask=model_inputs.attention_mask,
|
||||
pixel_values=model_inputs.pixel_values,
|
||||
image_grid_thw=model_inputs.image_grid_thw,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
hidden_states = outputs.hidden_states[-1]
|
||||
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
||||
)
|
||||
encoder_attention_mask = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Modified from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._encode_vae_image
|
||||
def encode_vae_image(
|
||||
image: torch.Tensor,
|
||||
vae: AutoencoderKLQwenImage,
|
||||
generator: torch.Generator,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
latent_channels: int = 16,
|
||||
sample_mode: str = "argmax",
|
||||
):
|
||||
if not isinstance(image, torch.Tensor):
|
||||
raise ValueError(f"Expected image to be a tensor, got {type(image)}.")
|
||||
|
||||
# preprocessed image should be a 4D tensor: batch_size, num_channels, height, width
|
||||
if image.dim() == 4:
|
||||
image = image.unsqueeze(2)
|
||||
elif image.dim() != 5:
|
||||
raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
|
||||
latents_mean = (
|
||||
torch.tensor(vae.config.latents_mean)
|
||||
.view(1, latent_channels, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(vae.config.latents_std)
|
||||
.view(1, latent_channels, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
image_latents = (image_latents - latents_mean) / latents_std
|
||||
|
||||
return image_latents
|
||||
|
||||
|
||||
class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
|
||||
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
|
||||
|
||||
This block resizes an input image tensor and exposes the resized result under configurable input and output
|
||||
names. Use this when you need to wire the resize step to different image fields (e.g., "image",
|
||||
"control_image")
|
||||
|
||||
Args:
|
||||
input_name (str, optional): Name of the image field to read from the
|
||||
pipeline state. Defaults to "image".
|
||||
output_name (str, optional): Name of the resized image field to write
|
||||
back to the pipeline state. Defaults to "resized_image".
|
||||
"""
|
||||
if not isinstance(input_name, str) or not isinstance(output_name, str):
|
||||
raise ValueError(
|
||||
f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
|
||||
)
|
||||
self._image_input_name = input_name
|
||||
self._resized_image_output_name = output_name
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_resize_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
images = getattr(block_state, self._image_input_name)
|
||||
|
||||
if not is_valid_image_imagelist(images):
|
||||
raise ValueError(f"Images must be image or list of images but are {type(images)}")
|
||||
|
||||
if is_valid_image(images):
|
||||
images = [images]
|
||||
|
||||
image_width, image_height = images[0].size
|
||||
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height)
|
||||
|
||||
resized_images = [
|
||||
components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
|
||||
for image in images
|
||||
]
|
||||
|
||||
setattr(block_state, self._resized_image_output_name, resized_images)
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that generate text_embeddings to guide the image generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration, description="The text encoder to use"),
|
||||
ComponentSpec("tokenizer", Qwen2Tokenizer, description="The tokenizer to use"),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(
|
||||
name="prompt_template_encode",
|
||||
default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
|
||||
),
|
||||
ConfigSpec(name="prompt_template_encode_start_idx", default=34),
|
||||
ConfigSpec(name="tokenizer_max_length", default=1024),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
|
||||
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
|
||||
InputParam(
|
||||
name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
name="prompt_embeds_mask",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The encoder attention mask",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The negative prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_prompt_embeds_mask",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The negative prompt embeddings mask",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(prompt, negative_prompt, max_sequence_length):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if (
|
||||
negative_prompt is not None
|
||||
and not isinstance(negative_prompt, str)
|
||||
and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length)
|
||||
|
||||
block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds(
|
||||
components.text_encoder,
|
||||
components.tokenizer,
|
||||
prompt=block_state.prompt,
|
||||
prompt_template_encode=components.config.prompt_template_encode,
|
||||
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
|
||||
tokenizer_max_length=components.config.tokenizer_max_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
|
||||
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]
|
||||
|
||||
if components.requires_unconditional_embeds:
|
||||
negative_prompt = block_state.negative_prompt or ""
|
||||
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
|
||||
components.text_encoder,
|
||||
components.tokenizer,
|
||||
prompt=negative_prompt,
|
||||
prompt_template_encode=components.config.prompt_template_encode,
|
||||
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
|
||||
tokenizer_max_length=components.config.tokenizer_max_length,
|
||||
device=device,
|
||||
)
|
||||
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[
|
||||
:, : block_state.max_sequence_length
|
||||
]
|
||||
block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[
|
||||
:, : block_state.max_sequence_length
|
||||
]
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
|
||||
ComponentSpec("processor", Qwen2VLProcessor),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(
|
||||
name="prompt_template_encode",
|
||||
default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
|
||||
),
|
||||
ConfigSpec(name="prompt_template_encode_start_idx", default=64),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
|
||||
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
|
||||
InputParam(
|
||||
name="resized_image",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The image prompt to encode, should be resized using resize step",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
name="prompt_embeds_mask",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The encoder attention mask",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The negative prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_prompt_embeds_mask",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The negative prompt embeddings mask",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(prompt, negative_prompt):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if (
|
||||
negative_prompt is not None
|
||||
and not isinstance(negative_prompt, str)
|
||||
and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(block_state.prompt, block_state.negative_prompt)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit(
|
||||
components.text_encoder,
|
||||
components.processor,
|
||||
prompt=block_state.prompt,
|
||||
image=block_state.resized_image,
|
||||
prompt_template_encode=components.config.prompt_template_encode,
|
||||
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if components.requires_unconditional_embeds:
|
||||
negative_prompt = block_state.negative_prompt or ""
|
||||
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
|
||||
components.text_encoder,
|
||||
components.processor,
|
||||
prompt=negative_prompt,
|
||||
image=block_state.resized_image,
|
||||
prompt_template_encode=components.config.prompt_template_encode,
|
||||
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images can be resized first using QwenImageEditResizeDynamicStep."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_mask_processor",
|
||||
InpaintProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("mask_image", required=True),
|
||||
InputParam("resized_image"),
|
||||
InputParam("image"),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="processed_image"),
|
||||
OutputParam(name="processed_mask_image"),
|
||||
OutputParam(
|
||||
name="mask_overlay_kwargs",
|
||||
type_hint=Dict,
|
||||
description="The kwargs for the postprocess step to apply the mask overlay",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
if block_state.resized_image is None and block_state.image is None:
|
||||
raise ValueError("resized_image and image cannot be None at the same time")
|
||||
|
||||
if block_state.resized_image is None:
|
||||
image = block_state.image
|
||||
self.check_inputs(
|
||||
height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
|
||||
)
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
else:
|
||||
width, height = block_state.resized_image[0].size
|
||||
image = block_state.resized_image
|
||||
|
||||
block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = (
|
||||
components.image_mask_processor.preprocess(
|
||||
image=image,
|
||||
mask=block_state.mask_image,
|
||||
height=height,
|
||||
width=width,
|
||||
padding_mask_crop=block_state.padding_mask_crop,
|
||||
)
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("resized_image"),
|
||||
InputParam("image"),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="processed_image"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
if block_state.resized_image is None and block_state.image is None:
|
||||
raise ValueError("resized_image and image cannot be None at the same time")
|
||||
|
||||
if block_state.resized_image is None:
|
||||
image = block_state.image
|
||||
self.check_inputs(
|
||||
height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
|
||||
)
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
else:
|
||||
width, height = block_state.resized_image[0].size
|
||||
image = block_state.resized_image
|
||||
|
||||
block_state.processed_image = components.image_processor.preprocess(
|
||||
image=image,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_name: str = "processed_image",
|
||||
output_name: str = "image_latents",
|
||||
):
|
||||
"""Initialize a VAE encoder step for converting images to latent representations.
|
||||
|
||||
Both the input and output names are configurable so this block can be configured to process to different image
|
||||
inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
|
||||
|
||||
Args:
|
||||
input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
|
||||
Examples: "processed_image" or "processed_control_image"
|
||||
output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
|
||||
Examples: "image_latents" or "control_image_latents"
|
||||
|
||||
Examples:
|
||||
# Basic usage with default settings (includes image processor) QwenImageVaeEncoderDynamicStep()
|
||||
|
||||
# Custom input/output names for control image QwenImageVaeEncoderDynamicStep(
|
||||
input_name="processed_control_image", output_name="control_image_latents"
|
||||
)
|
||||
"""
|
||||
self._image_input_name = input_name
|
||||
self._image_latents_output_name = output_name
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
]
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(self._image_input_name, required=True),
|
||||
InputParam("generator"),
|
||||
]
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
self._image_latents_output_name,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image",
|
||||
)
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
dtype = components.vae.dtype
|
||||
|
||||
image = getattr(block_state, self._image_input_name)
|
||||
|
||||
# Encode image into latents
|
||||
image_latents = encode_vae_image(
|
||||
image=image,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
)
|
||||
|
||||
setattr(block_state, self._image_latents_output_name, image_latents)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "VAE Encoder step that converts `control_image` into latent representations control_image_latents.\n"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
ComponentSpec("controlnet", QwenImageControlNetModel),
|
||||
ComponentSpec(
|
||||
"control_image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam("control_image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"control_image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the control image",
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(block_state.height, block_state.width, components.vae_scale_factor)
|
||||
|
||||
device = components._execution_device
|
||||
dtype = components.vae.dtype
|
||||
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
|
||||
controlnet = unwrap_module(components.controlnet)
|
||||
if isinstance(controlnet, QwenImageMultiControlNetModel) and not isinstance(block_state.control_image, list):
|
||||
block_state.control_image = [block_state.control_image]
|
||||
|
||||
if isinstance(controlnet, QwenImageMultiControlNetModel):
|
||||
block_state.control_image_latents = []
|
||||
for control_image_ in block_state.control_image:
|
||||
control_image_ = components.control_image_processor.preprocess(
|
||||
image=control_image_,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
control_image_latents_ = encode_vae_image(
|
||||
image=control_image_,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
sample_mode="sample",
|
||||
)
|
||||
block_state.control_image_latents.append(control_image_latents_)
|
||||
|
||||
elif isinstance(controlnet, QwenImageControlNetModel):
|
||||
control_image = components.control_image_processor.preprocess(
|
||||
image=block_state.control_image,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
block_state.control_image_latents = encode_vae_image(
|
||||
image=control_image,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
sample_mode="sample",
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected controlnet to be a QwenImageControlNetModel or QwenImageMultiControlNetModel, got {type(controlnet)}"
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
@@ -0,0 +1,431 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import QwenImageMultiControlNetModel
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
def repeat_tensor_to_batch_size(
|
||||
input_name: str,
|
||||
input_tensor: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_images_per_prompt: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""Repeat tensor elements to match the final batch size.
|
||||
|
||||
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
|
||||
by repeating each element along dimension 0.
|
||||
|
||||
The input tensor must have batch size 1 or batch_size. The function will:
|
||||
- If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
|
||||
- If batch size equals batch_size: repeat each element num_images_per_prompt times
|
||||
|
||||
Args:
|
||||
input_name (str): Name of the input tensor (used for error messages)
|
||||
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
|
||||
batch_size (int): The base batch size (number of prompts)
|
||||
num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)
|
||||
|
||||
Raises:
|
||||
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
|
||||
|
||||
Examples:
|
||||
tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor,
|
||||
batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape:
|
||||
[4, 3]
|
||||
|
||||
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image",
|
||||
tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]])
|
||||
- shape: [4, 3]
|
||||
"""
|
||||
# make sure input is a tensor
|
||||
if not isinstance(input_tensor, torch.Tensor):
|
||||
raise ValueError(f"`{input_name}` must be a tensor")
|
||||
|
||||
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
|
||||
if input_tensor.shape[0] == 1:
|
||||
repeat_by = batch_size * num_images_per_prompt
|
||||
elif input_tensor.shape[0] == batch_size:
|
||||
repeat_by = num_images_per_prompt
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
|
||||
)
|
||||
|
||||
# expand the tensor to match the batch_size * num_images_per_prompt
|
||||
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
||||
def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> Tuple[int, int]:
|
||||
"""Calculate image dimensions from latent tensor dimensions.
|
||||
|
||||
This function converts latent space dimensions to image space dimensions by multiplying the latent height and width
|
||||
by the VAE scale factor.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
|
||||
Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
|
||||
vae_scale_factor (int): The scale factor used by the VAE to compress images.
|
||||
Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The calculated image dimensions as (height, width)
|
||||
|
||||
Raises:
|
||||
ValueError: If latents tensor doesn't have 4 or 5 dimensions
|
||||
|
||||
"""
|
||||
# make sure the latents are not packed
|
||||
if latents.ndim != 4 and latents.ndim != 5:
|
||||
raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")
|
||||
|
||||
latent_height, latent_width = latents.shape[-2:]
|
||||
|
||||
height = latent_height * vae_scale_factor
|
||||
width = latent_width * vae_scale_factor
|
||||
|
||||
return height, width
|
||||
|
||||
|
||||
class QwenImageTextInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
summary_section = (
|
||||
"Text input processing step that standardizes text embeddings for the pipeline.\n"
|
||||
"This step:\n"
|
||||
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
|
||||
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
|
||||
)
|
||||
|
||||
# Placement guidance
|
||||
placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps."
|
||||
|
||||
return summary_section + placement_section
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
|
||||
InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
|
||||
InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
|
||||
InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"batch_size",
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
|
||||
),
|
||||
OutputParam(
|
||||
"dtype",
|
||||
type_hint=torch.dtype,
|
||||
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(
|
||||
prompt_embeds,
|
||||
prompt_embeds_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_embeds_mask,
|
||||
):
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None")
|
||||
|
||||
if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None:
|
||||
raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`")
|
||||
|
||||
if prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]:
|
||||
raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
|
||||
|
||||
elif negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != prompt_embeds.shape[0]:
|
||||
raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`")
|
||||
|
||||
elif (
|
||||
negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]
|
||||
):
|
||||
raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(
|
||||
prompt_embeds=block_state.prompt_embeds,
|
||||
prompt_embeds_mask=block_state.prompt_embeds_mask,
|
||||
negative_prompt_embeds=block_state.negative_prompt_embeds,
|
||||
negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask,
|
||||
)
|
||||
|
||||
block_state.batch_size = block_state.prompt_embeds.shape[0]
|
||||
block_state.dtype = block_state.prompt_embeds.dtype
|
||||
|
||||
_, seq_len, _ = block_state.prompt_embeds.shape
|
||||
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1)
|
||||
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len
|
||||
)
|
||||
|
||||
if block_state.negative_prompt_embeds is not None:
|
||||
_, seq_len, _ = block_state.negative_prompt_embeds.shape
|
||||
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
|
||||
1, block_state.num_images_per_prompt, 1
|
||||
)
|
||||
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat(
|
||||
1, block_state.num_images_per_prompt, 1
|
||||
)
|
||||
block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
|
||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||
1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
|
||||
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
|
||||
|
||||
This is a dynamic block that allows you to configure which inputs to process.
|
||||
|
||||
Args:
|
||||
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||
These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
|
||||
list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
|
||||
additional_batch_inputs (List[str], optional):
|
||||
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||
Defaults to []. Examples: ["processed_mask_image"]
|
||||
|
||||
Examples:
|
||||
# Configure to process image_latents (default behavior) QwenImageInputsDynamicStep()
|
||||
|
||||
# Configure to process multiple image latent inputs
|
||||
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
|
||||
|
||||
# Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
)
|
||||
"""
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
# Functionality section
|
||||
summary_section = (
|
||||
"Input processing step that:\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
# Inputs info
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
# Placement guidance
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
# Add image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
# Add additional batch inputs
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# 1. Calculate height/width from latents
|
||||
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
# 2. Patchify the image latent tensor
|
||||
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
|
||||
|
||||
# 3. Expand batch size
|
||||
image_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=image_latent_input_name,
|
||||
input_tensor=image_latent_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, image_latent_input_name, image_latent_tensor)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_name, input_tensor)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="control_image_latents", required=True),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
if isinstance(components.controlnet, QwenImageMultiControlNetModel):
|
||||
control_image_latents = []
|
||||
# loop through each control_image_latents
|
||||
for i, control_image_latents_ in enumerate(block_state.control_image_latents):
|
||||
# 1. update height/width if not provided
|
||||
height, width = calculate_dimension_from_latents(control_image_latents_, components.vae_scale_factor)
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
# 2. pack
|
||||
control_image_latents_ = components.pachifier.pack_latents(control_image_latents_)
|
||||
|
||||
# 3. repeat to match the batch size
|
||||
control_image_latents_ = repeat_tensor_to_batch_size(
|
||||
input_name=f"control_image_latents[{i}]",
|
||||
input_tensor=control_image_latents_,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
control_image_latents.append(control_image_latents_)
|
||||
|
||||
block_state.control_image_latents = control_image_latents
|
||||
|
||||
else:
|
||||
# 1. update height/width if not provided
|
||||
height, width = calculate_dimension_from_latents(
|
||||
block_state.control_image_latents, components.vae_scale_factor
|
||||
)
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
# 2. pack
|
||||
block_state.control_image_latents = components.pachifier.pack_latents(block_state.control_image_latents)
|
||||
|
||||
# 3. repeat to match the batch size
|
||||
block_state.control_image_latents = repeat_tensor_to_batch_size(
|
||||
input_name="control_image_latents",
|
||||
input_tensor=block_state.control_image_latents,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
block_state.control_image_latents = block_state.control_image_latents
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
@@ -0,0 +1,841 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageControlNetBeforeDenoiserStep,
|
||||
QwenImageCreateMaskLatentsStep,
|
||||
QwenImageEditRoPEInputsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImagePrepareLatentsWithStrengthStep,
|
||||
QwenImageRoPEInputsStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
QwenImageSetTimestepsWithStrengthStep,
|
||||
)
|
||||
from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
|
||||
from .denoise import (
|
||||
QwenImageControlNetDenoiseStep,
|
||||
QwenImageDenoiseStep,
|
||||
QwenImageEditDenoiseStep,
|
||||
QwenImageEditInpaintDenoiseStep,
|
||||
QwenImageInpaintControlNetDenoiseStep,
|
||||
QwenImageInpaintDenoiseStep,
|
||||
QwenImageLoopBeforeDenoiserControlNet,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageControlNetVaeEncoderStep,
|
||||
QwenImageEditResizeDynamicStep,
|
||||
QwenImageEditTextEncoderStep,
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
QwenImageProcessImagesInputStep,
|
||||
QwenImageTextEncoderStep,
|
||||
QwenImageVaeEncoderDynamicStep,
|
||||
)
|
||||
from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# 1. QwenImage
|
||||
|
||||
## 1.1 QwenImage/text2image
|
||||
|
||||
#### QwenImage/decode
|
||||
#### (standard decode step works for most tasks except for inpaint)
|
||||
QwenImageDecodeBlocks = InsertableDict(
|
||||
[
|
||||
("decode", QwenImageDecoderStep()),
|
||||
("postprocess", QwenImageProcessImagesOutputStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageDecodeBlocks.values()
|
||||
block_names = QwenImageDecodeBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image."
|
||||
|
||||
|
||||
#### QwenImage/text2image presets
|
||||
TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("input", QwenImageTextInputsStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
## 1.2 QwenImage/inpaint
|
||||
|
||||
#### QwenImage/inpaint vae encoder
|
||||
QwenImageInpaintVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
(
|
||||
"preprocess",
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
), # image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
|
||||
("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageInpaintVaeEncoderBlocks.values()
|
||||
block_names = QwenImageInpaintVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step is used for processing image and mask inputs for inpainting tasks. It:\n"
|
||||
" - Resizes the image to the target size, based on `height` and `width`.\n"
|
||||
" - Processes and updates `image` and `mask_image`.\n"
|
||||
" - Creates `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
#### QwenImage/inpaint inputs
|
||||
QwenImageInpaintInputBlocks = InsertableDict(
|
||||
[
|
||||
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
|
||||
(
|
||||
"additional_inputs",
|
||||
QwenImageInputsDynamicStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageInpaintInputBlocks.values()
|
||||
block_names = QwenImageInpaintInputBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
# QwenImage/inpaint prepare latents
|
||||
QwenImageInpaintPrepareLatentsBlocks = InsertableDict(
|
||||
[
|
||||
("add_noise_to_latents", QwenImagePrepareLatentsWithStrengthStep()),
|
||||
("create_mask_latents", QwenImageCreateMaskLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageInpaintPrepareLatentsBlocks.values()
|
||||
block_names = QwenImageInpaintPrepareLatentsBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
|
||||
" - Add noise to the image latents to create the latents input for the denoiser.\n"
|
||||
" - Create the pachified latents `mask` based on the processedmask image.\n"
|
||||
)
|
||||
|
||||
|
||||
#### QwenImage/inpaint decode
|
||||
QwenImageInpaintDecodeBlocks = InsertableDict(
|
||||
[
|
||||
("decode", QwenImageDecoderStep()),
|
||||
("postprocess", QwenImageInpaintProcessImagesOutputStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageInpaintDecodeBlocks.values()
|
||||
block_names = QwenImageInpaintDecodeBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
|
||||
|
||||
|
||||
#### QwenImage/inpaint presets
|
||||
INPAINT_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageInpaintVaeEncoderStep()),
|
||||
("input", QwenImageInpaintInputStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageInpaintDenoiseStep()),
|
||||
("decode", QwenImageInpaintDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
## 1.3 QwenImage/img2img
|
||||
|
||||
#### QwenImage/img2img vae encoder
|
||||
QwenImageImg2ImgVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("preprocess", QwenImageProcessImagesInputStep()),
|
||||
("encode", QwenImageVaeEncoderDynamicStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = QwenImageImg2ImgVaeEncoderBlocks.values()
|
||||
block_names = QwenImageImg2ImgVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
#### QwenImage/img2img inputs
|
||||
QwenImageImg2ImgInputBlocks = InsertableDict(
|
||||
[
|
||||
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
|
||||
("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageImg2ImgInputBlocks.values()
|
||||
block_names = QwenImageImg2ImgInputBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
#### QwenImage/img2img presets
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageImg2ImgVaeEncoderStep()),
|
||||
("input", QwenImageImg2ImgInputStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
## 1.4 QwenImage/controlnet
|
||||
|
||||
#### QwenImage/controlnet presets
|
||||
CONTROLNET_BLOCKS = InsertableDict(
|
||||
[
|
||||
("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep()), # vae encoder step for control_image
|
||||
("controlnet_inputs", QwenImageControlNetInputsStep()), # additional input step for controlnet
|
||||
(
|
||||
"controlnet_before_denoise",
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
), # before denoise step (after set_timesteps step)
|
||||
(
|
||||
"controlnet_denoise_loop_before",
|
||||
QwenImageLoopBeforeDenoiserControlNet(),
|
||||
), # controlnet loop step (insert before the denoiseloop_denoiser)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
## 1.5 QwenImage/auto encoders
|
||||
|
||||
|
||||
#### for inpaint and img2img tasks
|
||||
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
|
||||
block_names = ["inpaint", "img2img"]
|
||||
block_trigger_inputs = ["mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
|
||||
+ " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
|
||||
+ " - if `mask_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# for controlnet tasks
|
||||
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageControlNetVaeEncoderStep]
|
||||
block_names = ["controlnet"]
|
||||
block_trigger_inputs = ["control_image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
|
||||
+ " - if `control_image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 1.6 QwenImage/auto inputs
|
||||
|
||||
|
||||
# text2image/inpaint/img2img
|
||||
class QwenImageAutoInputStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintInputStep, QwenImageImg2ImgInputStep, QwenImageTextInputsStep]
|
||||
block_names = ["inpaint", "img2img", "text2image"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
|
||||
" This is an auto pipeline block that works for text2image/inpaint/img2img tasks.\n"
|
||||
+ " - `QwenImageInpaintInputStep` (inpaint) is used when `processed_mask_image` is provided.\n"
|
||||
+ " - `QwenImageImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
|
||||
+ " - `QwenImageTextInputsStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# controlnet
|
||||
class QwenImageOptionalControlNetInputStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageControlNetInputsStep]
|
||||
block_names = ["controlnet"]
|
||||
block_trigger_inputs = ["control_image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Controlnet input step that prepare the control_image_latents input.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageControlNetInputsStep` (controlnet) is used when `control_image_latents` is provided.\n"
|
||||
+ " - if `control_image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 1.7 QwenImage/auto before denoise step
|
||||
# compose the steps into a BeforeDenoiseStep for text2image/img2img/inpaint tasks before combine into an auto step
|
||||
|
||||
# QwenImage/text2image before denoise
|
||||
QwenImageText2ImageBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageText2ImageBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageText2ImageBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageText2ImageBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for text2image task."
|
||||
|
||||
|
||||
# QwenImage/inpaint before denoise
|
||||
QwenImageInpaintBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageInpaintBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageInpaintBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
|
||||
# QwenImage/img2img before denoise
|
||||
QwenImageImg2ImgBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageImg2ImgBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageImg2ImgBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
|
||||
# auto before_denoise step for text2image, inpaint, img2img tasks
|
||||
class QwenImageAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageInpaintBeforeDenoiseStep,
|
||||
QwenImageImg2ImgBeforeDenoiseStep,
|
||||
QwenImageText2ImageBeforeDenoiseStep,
|
||||
]
|
||||
block_names = ["inpaint", "img2img", "text2image"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2img, inpainting, img2img tasks.\n"
|
||||
+ " - `QwenImageInpaintBeforeDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
|
||||
+ " - `QwenImageImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
|
||||
+ " - `QwenImageText2ImageBeforeDenoiseStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# auto before_denoise step for controlnet tasks
|
||||
class QwenImageOptionalControlNetBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageControlNetBeforeDenoiserStep]
|
||||
block_names = ["controlnet"]
|
||||
block_trigger_inputs = ["control_image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Controlnet before denoise step that prepare the controlnet input.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageControlNetBeforeDenoiserStep` (controlnet) is used when `control_image_latents` is provided.\n"
|
||||
+ " - if `control_image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 1.8 QwenImage/auto denoise
|
||||
|
||||
|
||||
# auto denoise step for controlnet tasks: works for all tasks with controlnet
|
||||
class QwenImageControlNetAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintControlNetDenoiseStep, QwenImageControlNetDenoiseStep]
|
||||
block_names = ["inpaint_denoise", "denoise"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Controlnet step during the denoising process. \n"
|
||||
" This is an auto pipeline block that works for inpaint and text2image/img2img tasks with controlnet.\n"
|
||||
+ " - `QwenImageInpaintControlNetDenoiseStep` (inpaint) is used when `mask` is provided.\n"
|
||||
+ " - `QwenImageControlNetDenoiseStep` (text2image/img2img) is used when `mask` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# auto denoise step for everything: works for all tasks with or without controlnet
|
||||
class QwenImageAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageControlNetAutoDenoiseStep,
|
||||
QwenImageInpaintDenoiseStep,
|
||||
QwenImageDenoiseStep,
|
||||
]
|
||||
block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
|
||||
block_trigger_inputs = ["control_image_latents", "mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks. It also works with controlnet\n"
|
||||
+ " - `QwenImageControlNetAutoDenoiseStep` (controlnet) is used when `control_image_latents` is provided.\n"
|
||||
+ " - `QwenImageInpaintDenoiseStep` (inpaint) is used when `mask` is provided and `control_image_latents` is not provided.\n"
|
||||
+ " - `QwenImageDenoiseStep` (text2image/img2img) is used when `mask` is not provided and `control_image_latents` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
## 1.9 QwenImage/auto decode
|
||||
# auto decode step for inpaint and text2image tasks
|
||||
|
||||
|
||||
class QwenImageAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
|
||||
block_names = ["inpaint_decode", "decode"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Decode step that decode the latents into images. \n"
|
||||
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
|
||||
+ " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
|
||||
+ " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
## 1.10 QwenImage/auto block & presets
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageAutoVaeEncoderStep()),
|
||||
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
|
||||
("input", QwenImageAutoInputStep()),
|
||||
("controlnet_input", QwenImageOptionalControlNetInputStep()),
|
||||
("before_denoise", QwenImageAutoBeforeDenoiseStep()),
|
||||
("controlnet_before_denoise", QwenImageOptionalControlNetBeforeDenoiseStep()),
|
||||
("denoise", QwenImageAutoDenoiseStep()),
|
||||
("decode", QwenImageAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
|
||||
+ "- for image-to-image generation, you need to provide `image`\n"
|
||||
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
|
||||
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
)
|
||||
|
||||
|
||||
# 2. QwenImage-Edit
|
||||
|
||||
## 2.1 QwenImage-Edit/edit
|
||||
|
||||
#### QwenImage-Edit/edit vl encoder: take both image and text prompts
|
||||
QwenImageEditVLEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("resize", QwenImageEditResizeDynamicStep()),
|
||||
("encode", QwenImageEditTextEncoderStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditVLEncoderBlocks.values()
|
||||
block_names = QwenImageEditVLEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "QwenImage-Edit VL encoder step that encode the image an text prompts together."
|
||||
|
||||
|
||||
#### QwenImage-Edit/edit vae encoder
|
||||
QwenImageEditVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("resize", QwenImageEditResizeDynamicStep()), # edit has a different resize step
|
||||
("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image
|
||||
("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditVaeEncoderBlocks.values()
|
||||
block_names = QwenImageEditVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that encode the image inputs into their latent representations."
|
||||
|
||||
|
||||
#### QwenImage-Edit/edit input
|
||||
QwenImageEditInputBlocks = InsertableDict(
|
||||
[
|
||||
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
|
||||
("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditInputBlocks.values()
|
||||
block_names = QwenImageEditInputBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the edit denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs: \n"
|
||||
" - `image_latents`.\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
#### QwenImage/edit presets
|
||||
EDIT_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditVaeEncoderStep()),
|
||||
("input", QwenImageEditInputStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
("denoise", QwenImageEditDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
## 2.2 QwenImage-Edit/edit inpaint
|
||||
|
||||
#### QwenImage-Edit/edit inpaint vae encoder: the difference from regular inpaint is the resize step
|
||||
QwenImageEditInpaintVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("resize", QwenImageEditResizeDynamicStep()), # image -> resized_image
|
||||
(
|
||||
"preprocess",
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
), # resized_image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
|
||||
(
|
||||
"encode",
|
||||
QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"),
|
||||
), # processed_image -> image_latents
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditInpaintVaeEncoderBlocks.values()
|
||||
block_names = QwenImageEditInpaintVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
|
||||
" - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
|
||||
" - process the resized image and mask image.\n"
|
||||
" - create image latents."
|
||||
)
|
||||
|
||||
|
||||
#### QwenImage-Edit/edit inpaint presets
|
||||
EDIT_INPAINT_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditInpaintVaeEncoderStep()),
|
||||
("input", QwenImageInpaintInputStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
("denoise", QwenImageEditInpaintDenoiseStep()),
|
||||
("decode", QwenImageInpaintDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
## 2.3 QwenImage-Edit/auto encoders
|
||||
|
||||
|
||||
class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageEditInpaintVaeEncoderStep,
|
||||
QwenImageEditVaeEncoderStep,
|
||||
]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations. \n"
|
||||
" This is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
|
||||
+ " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
|
||||
+ " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
|
||||
+ " - if `mask_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 2.4 QwenImage-Edit/auto inputs
|
||||
class QwenImageEditAutoInputStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the edit denoising step.\n"
|
||||
+ " It is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
|
||||
+ " - `QwenImageInpaintInputStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
|
||||
+ " - `QwenImageEditInputStep` (edit) is used when `image_latents` is provided.\n"
|
||||
+ " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 2.5 QwenImage-Edit/auto before denoise
|
||||
# compose the steps into a BeforeDenoiseStep for edit and edit_inpaint tasks before combine into an auto step
|
||||
|
||||
#### QwenImage-Edit/edit before denoise
|
||||
QwenImageEditBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageEditBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
|
||||
|
||||
|
||||
#### QwenImage-Edit/edit inpaint before denoise
|
||||
QwenImageEditInpaintBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditInpaintBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageEditInpaintBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit inpaint task."
|
||||
|
||||
|
||||
# auto before_denoise step for edit and edit_inpaint tasks
|
||||
class QwenImageEditAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInpaintBeforeDenoiseStep,
|
||||
QwenImageEditBeforeDenoiseStep,
|
||||
]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for edit (img2img) and edit inpaint tasks.\n"
|
||||
+ " - `QwenImageEditInpaintBeforeDenoiseStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
|
||||
+ " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
|
||||
+ " - if `image_latents` or `processed_mask_image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 2.6 QwenImage-Edit/auto denoise
|
||||
|
||||
|
||||
class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
|
||||
block_classes = [QwenImageEditInpaintDenoiseStep, QwenImageEditDenoiseStep]
|
||||
block_names = ["inpaint_denoise", "denoise"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
+ "This block supports edit (img2img) and edit inpaint tasks for QwenImage Edit. \n"
|
||||
+ " - `QwenImageEditInpaintDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
|
||||
+ " - `QwenImageEditDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
|
||||
+ " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 2.7 QwenImage-Edit/auto blocks & presets
|
||||
|
||||
EDIT_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
|
||||
("input", QwenImageEditAutoInputStep()),
|
||||
("before_denoise", QwenImageEditAutoBeforeDenoiseStep()),
|
||||
("denoise", QwenImageEditAutoDenoiseStep()),
|
||||
("decode", QwenImageAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = EDIT_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
|
||||
+ "- for edit (img2img) generation, you need to provide `image`\n"
|
||||
+ "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
|
||||
)
|
||||
|
||||
|
||||
# 3. all block presets supported in QwenImage & QwenImage-Edit
|
||||
|
||||
|
||||
ALL_BLOCKS = {
|
||||
"text2image": TEXT2IMAGE_BLOCKS,
|
||||
"img2img": IMAGE2IMAGE_BLOCKS,
|
||||
"edit": EDIT_BLOCKS,
|
||||
"edit_inpaint": EDIT_INPAINT_BLOCKS,
|
||||
"inpaint": INPAINT_BLOCKS,
|
||||
"controlnet": CONTROLNET_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
"edit_auto": EDIT_AUTO_BLOCKS,
|
||||
}
|
||||
@@ -0,0 +1,202 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import QwenImageLoraLoaderMixin
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
|
||||
|
||||
class QwenImagePachifier(ConfigMixin):
|
||||
"""
|
||||
A class to pack and unpack latents for QwenImage.
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
def pack_latents(self, latents):
|
||||
if latents.ndim != 4 and latents.ndim != 5:
|
||||
raise ValueError(f"Latents must have 4 or 5 dimensions, but got {latents.ndim}")
|
||||
|
||||
if latents.ndim == 4:
|
||||
latents = latents.unsqueeze(2)
|
||||
|
||||
batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width = latents.shape
|
||||
patch_size = self.config.patch_size
|
||||
|
||||
if latent_height % patch_size != 0 or latent_width % patch_size != 0:
|
||||
raise ValueError(
|
||||
f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}"
|
||||
)
|
||||
|
||||
latents = latents.view(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
latent_height // patch_size,
|
||||
patch_size,
|
||||
latent_width // patch_size,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(
|
||||
0, 2, 4, 1, 3, 5
|
||||
) # Batch_size, num_patches_height, num_patches_width, num_channels_latents, patch_size, patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
(latent_height // patch_size) * (latent_width // patch_size),
|
||||
num_channels_latents * patch_size * patch_size,
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
def unpack_latents(self, latents, height, width, vae_scale_factor=8):
|
||||
if latents.ndim != 3:
|
||||
raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}")
|
||||
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
patch_size = self.config.patch_size
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = patch_size * (int(height) // (vae_scale_factor * patch_size))
|
||||
width = patch_size * (int(width) // (vae_scale_factor * patch_size))
|
||||
|
||||
latents = latents.view(
|
||||
batch_size,
|
||||
height // patch_size,
|
||||
width // patch_size,
|
||||
channels // (patch_size * patch_size),
|
||||
patch_size,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width)
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for QwenImage.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental feature and is likely to change in the future.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_width(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_sample_size(self):
|
||||
return 128
|
||||
|
||||
@property
|
||||
def vae_scale_factor(self):
|
||||
vae_scale_factor = 8
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def num_channels_latents(self):
|
||||
num_channels_latents = 16
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
return num_channels_latents
|
||||
|
||||
@property
|
||||
def is_guidance_distilled(self):
|
||||
is_guidance_distilled = False
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
is_guidance_distilled = self.transformer.config.guidance_embeds
|
||||
return is_guidance_distilled
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
requires_unconditional_embeds = False
|
||||
|
||||
if hasattr(self, "guider") and self.guider is not None:
|
||||
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
|
||||
|
||||
return requires_unconditional_embeds
|
||||
|
||||
|
||||
class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for QwenImage-Edit.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental feature and is likely to change in the future.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
# YiYi TODO: qwen edit should not provide default height/width, should be derived from the resized input image (after adjustment) produced by the resize step.
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_width(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_sample_size(self):
|
||||
return 128
|
||||
|
||||
@property
|
||||
def vae_scale_factor(self):
|
||||
vae_scale_factor = 8
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def num_channels_latents(self):
|
||||
num_channels_latents = 16
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
return num_channels_latents
|
||||
|
||||
@property
|
||||
def is_guidance_distilled(self):
|
||||
is_guidance_distilled = False
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
is_guidance_distilled = self.transformer.config.guidance_embeds
|
||||
return is_guidance_distilled
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
requires_unconditional_embeds = False
|
||||
|
||||
if hasattr(self, "guider") and self.guider is not None:
|
||||
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
|
||||
|
||||
return requires_unconditional_embeds
|
||||
@@ -76,6 +76,7 @@ class StableDiffusionXLModularPipeline(
|
||||
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
return vae_scale_factor
|
||||
|
||||
# YiYi TODO: change to num_channels_latents
|
||||
@property
|
||||
def num_channels_unet(self):
|
||||
num_channels_unet = 4
|
||||
|
||||
@@ -285,6 +285,7 @@ else:
|
||||
]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
[
|
||||
"MarigoldDepthPipeline",
|
||||
@@ -393,7 +394,9 @@ else:
|
||||
"QwenImageImg2ImgPipeline",
|
||||
"QwenImageInpaintPipeline",
|
||||
"QwenImageEditPipeline",
|
||||
"QwenImageEditPlusPipeline",
|
||||
"QwenImageEditInpaintPipeline",
|
||||
"QwenImageControlNetInpaintPipeline",
|
||||
"QwenImageControlNetPipeline",
|
||||
]
|
||||
try:
|
||||
@@ -681,6 +684,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
|
||||
from .lucy import LucyEditPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
from .marigold import (
|
||||
@@ -714,9 +718,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pia import PIAPipeline
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .qwenimage import (
|
||||
QwenImageControlNetInpaintPipeline,
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
QwenImageEditPipeline,
|
||||
QwenImageEditPlusPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImagePipeline,
|
||||
|
||||
@@ -651,6 +651,12 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
||||
deprecate(
|
||||
"enable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
@@ -658,6 +664,12 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
||||
deprecate(
|
||||
"disable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
@@ -666,6 +678,12 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
@@ -673,6 +691,12 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
@property
|
||||
|
||||
@@ -34,6 +34,7 @@ from transformers import (
|
||||
from ...models import AutoencoderKL
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_librosa_available,
|
||||
@@ -228,6 +229,12 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
||||
deprecate(
|
||||
"enable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
|
||||
@@ -236,6 +243,12 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
||||
deprecate(
|
||||
"disable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
|
||||
@@ -91,6 +91,14 @@ from .pag import (
|
||||
StableDiffusionXLPAGPipeline,
|
||||
)
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .qwenimage import (
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
QwenImageEditPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImagePipeline,
|
||||
)
|
||||
from .sana import SanaPipeline
|
||||
from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
|
||||
from .stable_diffusion import (
|
||||
@@ -150,6 +158,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
("cogview4", CogView4Pipeline),
|
||||
("cogview4-control", CogView4ControlPipeline),
|
||||
("qwenimage", QwenImagePipeline),
|
||||
("qwenimage-controlnet", QwenImageControlNetPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -174,6 +184,8 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("flux-controlnet", FluxControlNetImg2ImgPipeline),
|
||||
("flux-control", FluxControlImg2ImgPipeline),
|
||||
("flux-kontext", FluxKontextPipeline),
|
||||
("qwenimage", QwenImageImg2ImgPipeline),
|
||||
("qwenimage-edit", QwenImageEditPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -195,6 +207,8 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("flux-controlnet", FluxControlNetInpaintPipeline),
|
||||
("flux-control", FluxControlInpaintPipeline),
|
||||
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
|
||||
("qwenimage", QwenImageInpaintPipeline),
|
||||
("qwenimage-edit", QwenImageEditInpaintPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -19,11 +19,7 @@ from transformers import CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import PNDMScheduler
|
||||
from ...utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
from .blip_image_processing import BlipImageProcessor
|
||||
|
||||
@@ -25,6 +25,7 @@ from ...models import AutoencoderKL, ChromaTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -508,6 +509,12 @@ class ChromaPipeline(
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
||||
deprecate(
|
||||
"enable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
@@ -515,6 +522,12 @@ class ChromaPipeline(
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
||||
deprecate(
|
||||
"disable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
@@ -523,6 +536,12 @@ class ChromaPipeline(
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
@@ -530,6 +549,12 @@ class ChromaPipeline(
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
|
||||
@@ -663,11 +688,11 @@ class ChromaPipeline(
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||
|
||||
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
|
||||
@@ -25,6 +25,7 @@ from ...models import AutoencoderKL, ChromaTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -542,6 +543,12 @@ class ChromaImg2ImgPipeline(
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
||||
deprecate(
|
||||
"enable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
@@ -549,6 +556,12 @@ class ChromaImg2ImgPipeline(
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
||||
deprecate(
|
||||
"disable_vae_slicing",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
@@ -557,6 +570,12 @@ class ChromaImg2ImgPipeline(
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
||||
deprecate(
|
||||
"enable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
@@ -564,6 +583,12 @@ class ChromaImg2ImgPipeline(
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
||||
deprecate(
|
||||
"disable_vae_tiling",
|
||||
"0.40.0",
|
||||
depr_message,
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
||||
@@ -724,12 +749,12 @@ class ChromaImg2ImgPipeline(
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||
|
||||
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
strength (`float, *optional*, defaults to 0.9):
|
||||
Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
|
||||
be used as a starting point, adding more noise to it the larger the strength. The number of denoising
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user