Merge branch 'main' into sanitize-pipe-go-tests
This commit is contained in:
@@ -110,8 +110,9 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -116,8 +116,9 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -253,9 +254,10 @@ jobs:
|
||||
python -m uv pip install -e [quality,test]
|
||||
# TODO (sayakpaul, DN6): revisit `--no-deps`
|
||||
python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -U tokenizers
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# python -m uv pip install -U tokenizers
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -132,8 +132,9 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -203,8 +204,9 @@ jobs:
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -266,7 +268,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -23,13 +23,9 @@
|
||||
- local: using-diffusers/reusing_seeds
|
||||
title: Reproducibility
|
||||
- local: using-diffusers/schedulers
|
||||
title: Load schedulers and models
|
||||
- local: using-diffusers/models
|
||||
title: Models
|
||||
- local: using-diffusers/scheduler_features
|
||||
title: Scheduler features
|
||||
title: Schedulers
|
||||
- local: using-diffusers/other-formats
|
||||
title: Model files and layouts
|
||||
title: Model formats
|
||||
- local: using-diffusers/push_to_hub
|
||||
title: Sharing pipelines and models
|
||||
|
||||
@@ -68,10 +64,14 @@
|
||||
title: Accelerate inference
|
||||
- local: optimization/cache
|
||||
title: Caching
|
||||
- local: optimization/attention_backends
|
||||
title: Attention backends
|
||||
- local: optimization/memory
|
||||
title: Reduce memory usage
|
||||
- local: optimization/speed-memory-optims
|
||||
title: Compiling and offloading quantized models
|
||||
- local: api/parallel
|
||||
title: Parallel inference
|
||||
- title: Community optimizations
|
||||
sections:
|
||||
- local: optimization/pruna
|
||||
@@ -82,6 +82,8 @@
|
||||
title: Token merging
|
||||
- local: optimization/deepcache
|
||||
title: DeepCache
|
||||
- local: optimization/cache_dit
|
||||
title: CacheDiT
|
||||
- local: optimization/tgate
|
||||
title: TGATE
|
||||
- local: optimization/xdit
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Parallelism
|
||||
|
||||
Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times.
|
||||
|
||||
## ParallelConfig
|
||||
|
||||
[[autodoc]] ParallelConfig
|
||||
|
||||
## ContextParallelConfig
|
||||
|
||||
[[autodoc]] ContextParallelConfig
|
||||
|
||||
[[autodoc]] hooks.apply_context_parallel
|
||||
@@ -26,6 +26,7 @@ Qwen-Image comes in the following variants:
|
||||
|:----------:|:--------:|
|
||||
| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
|
||||
| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |
|
||||
| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -96,6 +97,29 @@ The `guidance_scale` parameter in the pipeline is there to support future guidan
|
||||
|
||||
</Tip>
|
||||
|
||||
## Multi-image reference with QwenImageEditPlusPipeline
|
||||
|
||||
With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
|
||||
|
||||
```
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import QwenImageEditPlusPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = QwenImageEditPlusPipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
|
||||
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
|
||||
image = pipe(
|
||||
image=[image_1, image_2],
|
||||
prompt="put the penguin and the cat at a game show called "Qwen Edit Plus Games"",
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## QwenImagePipeline
|
||||
|
||||
[[autodoc]] QwenImagePipeline
|
||||
@@ -126,7 +150,15 @@ The `guidance_scale` parameter in the pipeline is there to support future guidan
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImaggeControlNetPipeline
|
||||
## QwenImageControlNetPipeline
|
||||
|
||||
[[autodoc]] QwenImageControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImageEditPlusPipeline
|
||||
|
||||
[[autodoc]] QwenImageEditPlusPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Attention backends
|
||||
|
||||
> [!NOTE]
|
||||
> The attention dispatcher is an experimental feature. Please open an issue if you have any feedback or encounter any problems.
|
||||
|
||||
Diffusers provides several optimized attention algorithms that are more memory and computationally efficient through it's *attention dispatcher*. The dispatcher acts as a router for managing and switching between different attention implementations and provides a unified interface for interacting with them.
|
||||
|
||||
Refer to the table below for an overview of the available attention families and to the [Available backends](#available-backends) section for a more complete list.
|
||||
|
||||
| attention family | main feature |
|
||||
|---|---|
|
||||
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
|
||||
| SageAttention | quantizes attention to int8 |
|
||||
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
|
||||
| xFormers | memory-efficient attention with support for various attention kernels |
|
||||
|
||||
This guide will show you how to set and use the different attention backends.
|
||||
|
||||
## set_attention_backend
|
||||
|
||||
The [`~ModelMixin.set_attention_backend`] method iterates through all the modules in the model and sets the appropriate attention backend to use. The attention backend setting persists until [`~ModelMixin.reset_attention_backend`] is called.
|
||||
|
||||
The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [kernel](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
|
||||
|
||||
> [!NOTE]
|
||||
> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend("flash")`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImagePipeline
|
||||
|
||||
pipeline = QwenImagePipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
pipeline.transformer.set_attention_backend("_flash_3_hub")
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
pipeline(prompt).images[0]
|
||||
```
|
||||
|
||||
To restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].
|
||||
|
||||
```py
|
||||
pipeline.transformer.reset_attention_backend()
|
||||
```
|
||||
|
||||
## attention_backend context manager
|
||||
|
||||
The [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager temporarily sets an attention backend for a model within the context. Outside the context, the default attention (PyTorch's native scaled dot product attention) is used. This is useful if you want to use different backends for different parts of a pipeline or if you want to test the different backends.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImagePipeline
|
||||
|
||||
pipeline = QwenImagePipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
|
||||
with attention_backend("_flash_3_hub"):
|
||||
image = pipeline(prompt).images[0]
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Most attention backends support `torch.compile` without graph breaks and can be used to further speed up inference.
|
||||
|
||||
## Available backends
|
||||
|
||||
Refer to the table below for a complete list of available attention backends and their variants.
|
||||
|
||||
<details>
|
||||
<summary>Expand</summary>
|
||||
|
||||
| Backend Name | Family | Description |
|
||||
|--------------|--------|-------------|
|
||||
| `native` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Default backend using PyTorch's scaled_dot_product_attention |
|
||||
| `flex` | [FlexAttention](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention) | PyTorch FlexAttention implementation |
|
||||
| `_native_cudnn` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | CuDNN-optimized attention |
|
||||
| `_native_efficient` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Memory-efficient attention |
|
||||
| `_native_flash` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | PyTorch's FlashAttention |
|
||||
| `_native_math` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Math-based attention (fallback) |
|
||||
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
|
||||
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
|
||||
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
|
||||
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
|
||||
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
|
||||
| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
|
||||
| `_sage_qk_int8_pv_fp16_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (CUDA) |
|
||||
| `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) |
|
||||
| `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention |
|
||||
|
||||
</details>
|
||||
@@ -0,0 +1,270 @@
|
||||
## CacheDiT
|
||||
|
||||
CacheDiT is a unified, flexible, and training-free cache acceleration framework designed to support nearly all Diffusers' DiT-based pipelines. It provides a unified cache API that supports automatic block adapter, DBCache, and more.
|
||||
|
||||
To learn more, refer to the [CacheDiT](https://github.com/vipshop/cache-dit) repository.
|
||||
|
||||
Install a stable release of CacheDiT from PyPI or you can install the latest version from GitHub.
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="PyPI">
|
||||
|
||||
```bash
|
||||
pip3 install -U cache-dit
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="source">
|
||||
|
||||
```bash
|
||||
pip3 install git+https://github.com/vipshop/cache-dit.git
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Run the command below to view supported DiT pipelines.
|
||||
|
||||
```python
|
||||
>>> import cache_dit
|
||||
>>> cache_dit.supported_pipelines()
|
||||
(30, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',
|
||||
'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',
|
||||
'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',
|
||||
'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*'])
|
||||
```
|
||||
|
||||
For a complete benchmark, please refer to [Benchmarks](https://github.com/vipshop/cache-dit/blob/main/bench/).
|
||||
|
||||
|
||||
## Unified Cache API
|
||||
|
||||
CacheDiT works by matching specific input/output patterns as shown below.
|
||||
|
||||

|
||||
|
||||
Call the `enable_cache()` function on a pipeline to enable cache acceleration. This function is the entry point to many of CacheDiT's features.
|
||||
|
||||
```python
|
||||
import cache_dit
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
# Can be any diffusion pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
|
||||
|
||||
# One-line code with default cache options.
|
||||
cache_dit.enable_cache(pipe)
|
||||
|
||||
# Just call the pipe as normal.
|
||||
output = pipe(...)
|
||||
|
||||
# Disable cache and run original pipe.
|
||||
cache_dit.disable_cache(pipe)
|
||||
```
|
||||
|
||||
## Automatic Block Adapter
|
||||
|
||||
For custom or modified pipelines or transformers not included in Diffusers, use the `BlockAdapter` in `auto` mode or via manual configuration. Please check the [BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#automatic-block-adapter) docs for more details. Refer to [Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.
|
||||
|
||||
|
||||
```python
|
||||
from cache_dit import ForwardPattern, BlockAdapter
|
||||
|
||||
# Use 🔥BlockAdapter with `auto` mode.
|
||||
cache_dit.enable_cache(
|
||||
BlockAdapter(
|
||||
# Any DiffusionPipeline, Qwen-Image, etc.
|
||||
pipe=pipe, auto=True,
|
||||
# Check `📚Forward Pattern Matching` documentation and hack the code of
|
||||
# of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
|
||||
forward_pattern=ForwardPattern.Pattern_1,
|
||||
),
|
||||
)
|
||||
|
||||
# Or, manually setup transformer configurations.
|
||||
cache_dit.enable_cache(
|
||||
BlockAdapter(
|
||||
pipe=pipe, # Qwen-Image, etc.
|
||||
transformer=pipe.transformer,
|
||||
blocks=pipe.transformer.transformer_blocks,
|
||||
forward_pattern=ForwardPattern.Pattern_1,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
Sometimes, a Transformer class will contain more than one transformer `blocks`. For example, FLUX.1 (HiDream, Chroma, etc) contains `transformer_blocks` and `single_transformer_blocks` (with different forward patterns). The BlockAdapter is able to detect this hybrid pattern type as well.
|
||||
Refer to [FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.
|
||||
|
||||
```python
|
||||
# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
|
||||
# single_transformer_blocks have different forward patterns.
|
||||
cache_dit.enable_cache(
|
||||
BlockAdapter(
|
||||
pipe=pipe, # FLUX.1, etc.
|
||||
transformer=pipe.transformer,
|
||||
blocks=[
|
||||
pipe.transformer.transformer_blocks,
|
||||
pipe.transformer.single_transformer_blocks,
|
||||
],
|
||||
forward_pattern=[
|
||||
ForwardPattern.Pattern_1,
|
||||
ForwardPattern.Pattern_3,
|
||||
],
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
This also works if there is more than one transformer (namely `transformer` and `transformer_2`) in its structure. Refer to [Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.
|
||||
|
||||
## Patch Functor
|
||||
|
||||
For any pattern not included in CacheDiT, use the Patch Functor to convert the pattern into a known pattern. You need to subclass the Patch Functor and may also need to fuse the operations within the blocks for loop into block `forward`. After implementing a Patch Functor, set the `patch_functor` property in `BlockAdapter`.
|
||||
|
||||

|
||||
|
||||
Some Patch Functors are already provided in CacheDiT, [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py), [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py), etc.
|
||||
|
||||
```python
|
||||
@BlockAdapterRegistry.register("HiDream")
|
||||
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
|
||||
from diffusers import HiDreamImageTransformer2DModel
|
||||
from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
|
||||
|
||||
assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
|
||||
return BlockAdapter(
|
||||
pipe=pipe,
|
||||
transformer=pipe.transformer,
|
||||
blocks=[
|
||||
pipe.transformer.double_stream_blocks,
|
||||
pipe.transformer.single_stream_blocks,
|
||||
],
|
||||
forward_pattern=[
|
||||
ForwardPattern.Pattern_0,
|
||||
ForwardPattern.Pattern_3,
|
||||
],
|
||||
# NOTE: Setup your custom patch functor here.
|
||||
patch_functor=HiDreamPatchFunctor(),
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
Finally, you can call the `cache_dit.summary()` function on a pipeline after its completed inference to get the cache acceleration details.
|
||||
|
||||
```python
|
||||
stats = cache_dit.summary(pipe)
|
||||
```
|
||||
|
||||
```python
|
||||
⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline
|
||||
|
||||
| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
|
||||
|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
|
||||
| 23 | 0.045 | 0.084 | 0.114 | 0.147 | 0.241 | 0.297 |
|
||||
```
|
||||
|
||||
## DBCache: Dual Block Cache
|
||||
|
||||

|
||||
|
||||
DBCache (Dual Block Caching) supports different configurations of compute blocks (F8B12, etc.) to enable a balanced trade-off between performance and precision.
|
||||
- Fn_compute_blocks: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
|
||||
- Bn_compute_blocks: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
|
||||
|
||||
|
||||
```python
|
||||
import cache_dit
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe_or_adapter = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
# Default options, F8B0, 8 warmup steps, and unlimited cached
|
||||
# steps for good balance between performance and precision
|
||||
cache_dit.enable_cache(pipe_or_adapter)
|
||||
|
||||
# Custom options, F8B8, higher precision
|
||||
from cache_dit import BasicCacheConfig
|
||||
|
||||
cache_dit.enable_cache(
|
||||
pipe_or_adapter,
|
||||
cache_config=BasicCacheConfig(
|
||||
max_warmup_steps=8, # steps do not cache
|
||||
max_cached_steps=-1, # -1 means no limit
|
||||
Fn_compute_blocks=8, # Fn, F8, etc.
|
||||
Bn_compute_blocks=8, # Bn, B8, etc.
|
||||
residual_diff_threshold=0.12,
|
||||
),
|
||||
)
|
||||
```
|
||||
Check the [DBCache](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) and [User Guide](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#dbcache) docs for more design details.
|
||||
|
||||
## TaylorSeer Calibrator
|
||||
|
||||
The [TaylorSeers](https://huggingface.co/papers/2503.06923) algorithm further improves the precision of DBCache in cases where the cached steps are large (Hybrid TaylorSeer + DBCache). At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
|
||||
|
||||
TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. F_pred can be a residual cache or a hidden-state cache.
|
||||
|
||||
```python
|
||||
from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig
|
||||
|
||||
cache_dit.enable_cache(
|
||||
pipe_or_adapter,
|
||||
# Basic DBCache w/ FnBn configurations
|
||||
cache_config=BasicCacheConfig(
|
||||
max_warmup_steps=8, # steps do not cache
|
||||
max_cached_steps=-1, # -1 means no limit
|
||||
Fn_compute_blocks=8, # Fn, F8, etc.
|
||||
Bn_compute_blocks=8, # Bn, B8, etc.
|
||||
residual_diff_threshold=0.12,
|
||||
),
|
||||
# Then, you can use the TaylorSeer Calibrator to approximate
|
||||
# the values in cached steps, taylorseer_order default is 1.
|
||||
calibrator_config=TaylorSeerCalibratorConfig(
|
||||
taylorseer_order=1,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> The `Bn_compute_blocks` parameter of DBCache can be set to `0` if you use TaylorSeer as the calibrator for approximate hidden states. DBCache's `Bn_compute_blocks` also acts as a calibrator, so you can choose either `Bn_compute_blocks` > 0 or TaylorSeer. We recommend using the configuration scheme of TaylorSeer + DBCache FnB0.
|
||||
|
||||
## Hybrid Cache CFG
|
||||
|
||||
CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG in the forward step, please set `enable_separate_cfg` parameter to `False (default, None)`. Otherwise, set it to `True`.
|
||||
|
||||
```python
|
||||
from cache_dit import BasicCacheConfig
|
||||
|
||||
cache_dit.enable_cache(
|
||||
pipe_or_adapter,
|
||||
cache_config=BasicCacheConfig(
|
||||
...,
|
||||
# For example, set it as True for Wan 2.1, Qwen-Image
|
||||
# and set it as False for FLUX.1, HunyuanVideo, etc.
|
||||
enable_separate_cfg=True,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
## torch.compile
|
||||
|
||||
CacheDiT is designed to work with torch.compile for even better performance. Call `torch.compile` after enabling the cache.
|
||||
|
||||
|
||||
```python
|
||||
cache_dit.enable_cache(pipe)
|
||||
|
||||
# Compile the Transformer module
|
||||
pipe.transformer = torch.compile(pipe.transformer)
|
||||
```
|
||||
|
||||
If you're using CacheDiT with dynamic input shapes, consider increasing the `recompile_limit` of `torch._dynamo`. Otherwise, the `recompile_limit` error may be triggered, causing the module to fall back to eager mode.
|
||||
|
||||
```python
|
||||
torch._dynamo.config.recompile_limit = 96 # default is 8
|
||||
torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
|
||||
```
|
||||
|
||||
Please check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.
|
||||
@@ -11,69 +11,96 @@ specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# torchao
|
||||
|
||||
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
|
||||
[torchao](https://github.com/pytorch/ao) provides high-performance dtypes and optimizations based on quantization and sparsity for inference and training PyTorch models. It is supported for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
|
||||
|
||||
Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed.
|
||||
Make sure Pytorch 2.5+ and torchao are installed with the command below.
|
||||
|
||||
```bash
|
||||
pip install -U torch torchao
|
||||
uv pip install -U torch torchao
|
||||
```
|
||||
|
||||
Each quantization dtype is available as a separate instance of a [AOBaseConfig](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) class. This provides more flexible configuration options by exposing more available arguments.
|
||||
|
||||
Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
|
||||
Pass the `AOBaseConfig` of a quantization dtype, like [Int4WeightOnlyConfig](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig) to [`TorchAoConfig`] in [`~ModelMixin.from_pretrained`].
|
||||
|
||||
The example below only quantizes the weights to int8.
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))}
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantzation_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={"transformer": TorchAoConfig("int8wo")}
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantzation_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
## torch.compile
|
||||
|
||||
torchao supports [torch.compile](../optimization/fp16#torchcompile) which can speed up inference with one line of code.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
|
||||
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = AutoModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=dtype,
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))}
|
||||
)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
model_id,
|
||||
transformer=transformer,
|
||||
torch_dtype=dtype,
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantzation_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Without quantization: ~31.447 GB
|
||||
# With quantization: ~20.40 GB
|
||||
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
pipeline.transformer.compile(transformer, mode="max-autotune", fullgraph=True)
|
||||
```
|
||||
|
||||
TorchAO is fully compatible with [torch.compile](../optimization/fp16#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
|
||||
|
||||
```python
|
||||
# In the above code, add the following after initializing the transformer
|
||||
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
|
||||
```
|
||||
|
||||
For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
|
||||
Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450) for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao [repository](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
|
||||
|
||||
> [!TIP]
|
||||
> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
|
||||
|
||||
torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
|
||||
## autoquant
|
||||
|
||||
The `TorchAoConfig` class accepts three parameters:
|
||||
- `quant_type`: A string value mentioning one of the quantization types below.
|
||||
- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`.
|
||||
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
|
||||
torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from torchao.quantization import autoquant
|
||||
|
||||
# Load the pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
transformer = autoquant(pipeline.transformer)
|
||||
```
|
||||
|
||||
## Supported quantization types
|
||||
|
||||
|
||||
@@ -12,17 +12,23 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Distributed inference
|
||||
|
||||
On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel.
|
||||
Distributed inference splits the workload across multiple GPUs. It a useful technique for fitting larger models in memory and can process multiple prompts for higher throughput.
|
||||
|
||||
This guide will show you how to use 🤗 Accelerate and PyTorch Distributed for distributed inference.
|
||||
This guide will show you how to use [Accelerate](https://huggingface.co/docs/accelerate/index) and [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html) for distributed inference.
|
||||
|
||||
## 🤗 Accelerate
|
||||
## Accelerate
|
||||
|
||||
🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code.
|
||||
Accelerate is a library designed to simplify inference and training on multiple accelerators by handling the setup, allowing users to focus on their PyTorch code.
|
||||
|
||||
To begin, create a Python file and initialize an [`accelerate.PartialState`] to create a distributed environment; your setup is automatically detected so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `distributed_state.device` to assign a GPU to each process.
|
||||
Install Accelerate with the following command.
|
||||
|
||||
Now use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts between the number of processes.
|
||||
```bash
|
||||
uv pip install accelerate
|
||||
```
|
||||
|
||||
Initialize a [`accelerate.PartialState`] class in a Python file to create a distributed environment. The [`accelerate.PartialState`] class manages process management, device control and distribution, and process coordination.
|
||||
|
||||
Move the [`DiffusionPipeline`] to [`accelerate.PartialState.device`] to assign a GPU to each process.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -30,33 +36,31 @@ from accelerate import PartialState
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
"Qwen/Qwen-Image", torch_dtype=torch.float16
|
||||
)
|
||||
distributed_state = PartialState()
|
||||
pipeline.to(distributed_state.device)
|
||||
```
|
||||
|
||||
Use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts between the number of processes.
|
||||
|
||||
```py
|
||||
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
|
||||
result = pipeline(prompt).images[0]
|
||||
result.save(f"result_{distributed_state.process_index}.png")
|
||||
```
|
||||
|
||||
Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script:
|
||||
Call `accelerate launch` to run the script and use the `--num_processes` argument to set the number of GPUs to use.
|
||||
|
||||
```bash
|
||||
accelerate launch run_distributed.py --num_processes=2
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
Refer to this minimal example [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for running inference across multiple GPUs. To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
## PyTorch Distributed
|
||||
|
||||
PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism.
|
||||
PyTorch [DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) enables [data parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=data_parallelism), which replicates the same model on each device, to process different batches of data in parallel.
|
||||
|
||||
To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn the processes for inference on each GPU. You should also initialize a [`DiffusionPipeline`]:
|
||||
Import `torch.distributed` and `torch.multiprocessing` into a Python file to set up the distributed process group and to spawn the processes for inference on each GPU.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -65,20 +69,20 @@ import torch.multiprocessing as mp
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
sd = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image", torch_dtype=torch.float16,
|
||||
)
|
||||
```
|
||||
|
||||
You'll want to create a function to run inference; [`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) handles creating a distributed environment with the type of backend to use, the `rank` of the current process, and the `world_size` or the number of processes participating. If you're running inference in parallel over 2 GPUs, then the `world_size` is 2.
|
||||
Create a function for inference with [init_process_group](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group). This method creates a distributed environment with the backend type, the `rank` of the current process, and the `world_size` or number of processes participating (for example, 2 GPUs would be `world_size=2`).
|
||||
|
||||
Move the [`DiffusionPipeline`] to `rank` and use `get_rank` to assign a GPU to each process, where each process handles a different prompt:
|
||||
Move the pipeline to `rank` and use `get_rank` to assign a GPU to each process. Each process handles a different prompt.
|
||||
|
||||
```py
|
||||
def run_inference(rank, world_size):
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
|
||||
sd.to(rank)
|
||||
pipeline.to(rank)
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
prompt = "a dog"
|
||||
@@ -89,7 +93,7 @@ def run_inference(rank, world_size):
|
||||
image.save(f"./{'_'.join(prompt)}.png")
|
||||
```
|
||||
|
||||
To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`:
|
||||
Use [mp.spawn](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to create the number of processes defined in `world_size`.
|
||||
|
||||
```py
|
||||
def main():
|
||||
@@ -101,31 +105,26 @@ if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
Once you've completed the inference script, use the `--nproc_per_node` argument to specify the number of GPUs to use and call `torchrun` to run the script:
|
||||
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
|
||||
|
||||
```bash
|
||||
torchrun run_distributed.py --nproc_per_node=2
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
|
||||
## device_map
|
||||
|
||||
## Model sharding
|
||||
The `device_map` argument enables distributed inference by automatically placing model components on separate GPUs. This is especially useful when a model doesn't fit on a single GPU. You can use `device_map` to selectively load and unload the required model components at a given stage as shown in the example below (assumes two GPUs are available).
|
||||
|
||||
Modern diffusion systems such as [Flux](../api/pipelines/flux) are very large and have multiple models. For example, [Flux.1-Dev](https://hf.co/black-forest-labs/FLUX.1-dev) is made up of two text encoders - [T5-XXL](https://hf.co/google/t5-v1_1-xxl) and [CLIP-L](https://hf.co/openai/clip-vit-large-patch14) - a [diffusion transformer](../api/models/flux_transformer), and a [VAE](../api/models/autoencoderkl). With a model this size, it can be challenging to run inference on consumer GPUs.
|
||||
|
||||
Model sharding is a technique that distributes models across GPUs when the models don't fit on a single GPU. The example below assumes two 16GB GPUs are available for inference.
|
||||
|
||||
Start by computing the text embeddings with the text encoders. Keep the text encoders on two GPUs by setting `device_map="balanced"`. The `balanced` strategy evenly distributes the model on all available GPUs. Use the `max_memory` parameter to allocate the maximum amount of memory for each text encoder on each GPU.
|
||||
|
||||
> [!TIP]
|
||||
> **Only** load the text encoders for this step! The diffusion transformer and VAE are loaded in a later step to preserve memory.
|
||||
Set `device_map="balanced"` to evenly distributes the text encoders on all available GPUs. You can use the `max_memory` argument to allocate a maximum amount of memory for each text encoder. Don't load any other pipeline components to avoid memory usage.
|
||||
|
||||
```py
|
||||
from diffusers import FluxPipeline
|
||||
import torch
|
||||
|
||||
prompt = "a photo of a dog with cat-like look"
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
@@ -142,7 +141,7 @@ with torch.no_grad():
|
||||
)
|
||||
```
|
||||
|
||||
Once the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer.
|
||||
After the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer.
|
||||
|
||||
```py
|
||||
import gc
|
||||
@@ -162,7 +161,7 @@ del pipeline
|
||||
flush()
|
||||
```
|
||||
|
||||
Load the diffusion transformer next which has 12.5B parameters. This time, set `device_map="auto"` to automatically distribute the model across two 16GB GPUs. The `auto` strategy is backed by [Accelerate](https://hf.co/docs/accelerate/index) and available as a part of the [Big Model Inference](https://hf.co/docs/accelerate/concept_guides/big_model_inference) feature. It starts by distributing a model across the fastest device first (GPU) before moving to slower devices like the CPU and hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.
|
||||
Set `device_map="auto"` to automatically distribute the model on the two GPUs. This strategy places a model on the fastest device first before placing a model on a slower device like a CPU or hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
@@ -177,9 +176,9 @@ transformer = AutoModel.from_pretrained(
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. You can also try `print(transformer.hf_device_map)` to see how the transformer model is sharded across devices.
|
||||
> Run `pipeline.hf_device_map` to see how the various models are distributed across devices. This is useful for tracking model device placement. You can also call `hf_device_map` on the transformer model to see how it is distributed.
|
||||
|
||||
Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.
|
||||
Add the transformer model to the pipeline and set the `output_type="latent"` to generate the latents.
|
||||
|
||||
```py
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
@@ -206,21 +205,12 @@ latents = pipeline(
|
||||
).images
|
||||
```
|
||||
|
||||
Remove the pipeline and transformer from memory as they're no longer needed.
|
||||
|
||||
```py
|
||||
del pipeline.transformer
|
||||
del pipeline
|
||||
|
||||
flush()
|
||||
```
|
||||
|
||||
Finally, decode the latents with the VAE into an image. The VAE is typically small enough to be loaded on a single GPU.
|
||||
Remove the pipeline and transformer from memory and load a VAE to decode the latents. The VAE is typically small enough to be loaded on a single device.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
import torch
|
||||
|
||||
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
|
||||
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
||||
@@ -236,4 +226,8 @@ with torch.no_grad():
|
||||
image[0].save("split_transformer.png")
|
||||
```
|
||||
|
||||
By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
|
||||
## Resources
|
||||
|
||||
- Take a look at this [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for a minimal example of distributed inference with Accelerate.
|
||||
- For more details, check out Accelerate's [Distributed inference](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
|
||||
- The `device_map` argument assign models or an entire pipeline to devices. Refer to the [device placement](../using-diffusers/loading#device-placement) docs for more information.
|
||||
@@ -52,6 +52,9 @@ pipeline = QwenImagePipeline.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Refer to the [Single file format](./other-formats#single-file-format) docs to learn how to load single file models.
|
||||
|
||||
### Local pipelines
|
||||
|
||||
Pipelines can also be run locally. Use [`~huggingface_hub.snapshot_download`] to download a model repository.
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Models
|
||||
|
||||
A diffusion model relies on a few individual models working together to generate an output. These models are responsible for denoising, encoding inputs, and decoding latents into the actual outputs.
|
||||
|
||||
This guide will show you how to load models.
|
||||
|
||||
## Loading a model
|
||||
|
||||
All models are loaded with the [`~ModelMixin.from_pretrained`] method, which downloads and caches the latest model version. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache.
|
||||
|
||||
Pass the `subfolder` argument to [`~ModelMixin.from_pretrained`] to specify where to load the model weights from. Omit the `subfolder` argument if the repository doesn't have a subfolder structure or if you're loading a standalone model.
|
||||
|
||||
```py
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer")
|
||||
```
|
||||
|
||||
## AutoModel
|
||||
|
||||
[`AutoModel`] detects the model class from a `model_index.json` file or a model's `config.json` file. It fetches the correct model class from these files and delegates the actual loading to the model class. [`AutoModel`] is useful for automatic model type detection without needing to know the exact model class beforehand.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"Qwen/Qwen-Image", subfolder="transformer"
|
||||
)
|
||||
```
|
||||
|
||||
## Model data types
|
||||
|
||||
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to load a model with a specific data type. This allows you to load a model in a lower precision to reduce memory usage.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained(
|
||||
"Qwen/Qwen-Image",
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
[nn.Module.to](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to) can also convert to a specific data type on the fly. However, it converts *all* weights to the requested data type unlike `torch_dtype` which respects `_keep_in_fp32_modules`. This argument preserves layers in `torch.float32` for numerical stability and best generation quality (see example [_keep_in_fp32_modules](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374))
|
||||
|
||||
```py
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained(
|
||||
"Qwen/Qwen-Image", subfolder="transformer"
|
||||
)
|
||||
model = model.to(dtype=torch.float16)
|
||||
```
|
||||
|
||||
## Device placement
|
||||
|
||||
Use the `device_map` argument in [`~ModelMixin.from_pretrained`] to place a model on an accelerator like a GPU. It is especially helpful where there are multiple GPUs.
|
||||
|
||||
Diffusers currently provides three options to `device_map` for individual models, `"cuda"`, `"balanced"` and `"auto"`. Refer to the table below to compare the three placement strategies.
|
||||
|
||||
| parameter | description |
|
||||
|---|---|
|
||||
| `"cuda"` | places pipeline on a supported accelerator (CUDA) |
|
||||
| `"balanced"` | evenly distributes pipeline on all GPUs |
|
||||
| `"auto"` | distribute model from fastest device first to slowest |
|
||||
|
||||
Use the `max_memory` argument in [`~ModelMixin.from_pretrained`] to allocate a maximum amount of memory to use on each device. By default, Diffusers uses the maximum amount available.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImagePipeline
|
||||
|
||||
max_memory = {0: "16GB", 1: "16GB"}
|
||||
pipeline = QwenImagePipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda",
|
||||
max_memory=max_memory
|
||||
)
|
||||
```
|
||||
|
||||
The `hf_device_map` attribute allows you to access and view the `device_map`.
|
||||
|
||||
```py
|
||||
print(transformer.hf_device_map)
|
||||
# {'': device(type='cuda')}
|
||||
```
|
||||
|
||||
## Saving models
|
||||
|
||||
Save a model with the [`~ModelMixin.save_pretrained`] method.
|
||||
|
||||
```py
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer")
|
||||
model.save_pretrained("./local/model")
|
||||
```
|
||||
|
||||
For large models, it is helpful to use `max_shard_size` to save a model as multiple shards. A shard can be loaded faster and save memory (refer to the [parallel loading](./loading#parallel-loading) docs for more details), especially if there is more than one GPU.
|
||||
|
||||
```py
|
||||
model.save_pretrained("./local/model", max_shard_size="5GB")
|
||||
```
|
||||
@@ -10,77 +10,183 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Model files and layouts
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
Diffusion models are saved in various file types and organized in different layouts. Diffusers stores model weights as safetensors files in *Diffusers-multifolder* layout and it also supports loading files (like safetensors and ckpt files) from a *single-file* layout which is commonly used in the diffusion ecosystem.
|
||||
# Model formats
|
||||
|
||||
Each layout has its own benefits and use cases, and this guide will show you how to load the different files and layouts, and how to convert them.
|
||||
Diffusion models are typically stored in the Diffusers format or single-file format. Model files can be stored in various file types such as safetensors, dduf, or ckpt.
|
||||
|
||||
## Files
|
||||
> [!TIP]
|
||||
> Format refers to whether the weights are stored in a directory structure and file refers to the file type.
|
||||
|
||||
PyTorch model weights are typically saved with Python's [pickle](https://docs.python.org/3/library/pickle.html) utility as ckpt or bin files. However, pickle is not secure and pickled files may contain malicious code that can be executed. This vulnerability is a serious concern given the popularity of model sharing. To address this security issue, the [Safetensors](https://hf.co/docs/safetensors) library was developed as a secure alternative to pickle, which saves models as safetensors files.
|
||||
This guide will show you how to load pipelines and models from these formats and files.
|
||||
|
||||
## Diffusers format
|
||||
|
||||
The Diffusers format stores each model (UNet, transformer, text encoder) in a separate subfolder. There are several benefits to storing models separately.
|
||||
|
||||
- Faster overall pipeline initialization because you can load the individual model you need or load them all in parallel.
|
||||
- Reduced memory usage because you don't need to load all the pipeline components if you only need one model. [Reuse](./loading#reusing-models-in-multiple-pipelines) a model that is shared between multiple pipelines.
|
||||
- Lower storage requirements because common models shared between multiple pipelines are only downloaded once.
|
||||
- Flexibility to use new or improved models in a pipeline.
|
||||
|
||||
## Single file format
|
||||
|
||||
A single-file format stores *all* the model (UNet, transformer, text encoder) weights in a single file. Benefits of single-file formats include the following.
|
||||
|
||||
- Greater compatibility with [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
|
||||
- Easier to download and share a single file.
|
||||
|
||||
Use [`~loaders.FromSingleFileMixin.from_single_file`] to load a single file.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
The [`~loaders.FromSingleFileMixin.from_single_file`] method also supports passing new models or schedulers.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel
|
||||
|
||||
transformer = FluxTransformer2DModel.from_single_file(
|
||||
"https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
### Configuration options
|
||||
|
||||
Diffusers format models have a `config.json` file in their repositories with important attributes such as the number of layers and attention heads. The [`~loaders.FromSingleFileMixin.from_single_file`] method automatically determines the appropriate config to use from `config.json`. This may fail in a few rare instances though, in which case, you should use the `config` argument.
|
||||
|
||||
You should also use the `config` argument if the models in a pipeline are different from the original implementation or if it doesn't have the necessary metadata to determine the correct config.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
ckpt_path = "https://huggingface.co/segmind/SSD-1B/blob/main/SSD-1B.safetensors"
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, config="segmind/SSD-1B")
|
||||
```
|
||||
|
||||
Diffusers attempts to infer the pipeline components based on the signature types of the pipeline class when using `original_config` with `local_files_only=True`. It won't download the config files from a Hub repository to avoid backward breaking changes when you can't connect to the internet. This method isn't as reliable as providing a path to a local model with the `config` argument and may lead to errors. You should run the pipeline with `local_files_only=False` to download the config files to the local cache to avoid errors.
|
||||
|
||||
Override default configs by passing the arguments directly to [`~loaders.FromSingleFileMixin.from_single_file`]. The examples below demonstrate how to override the configs in a pipeline or model.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLInstructPix2PixPipeline
|
||||
|
||||
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
|
||||
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_single_file(
|
||||
ckpt_path, config="diffusers/sdxl-instructpix2pix-768", is_cosxl_edit=True
|
||||
)
|
||||
```
|
||||
|
||||
```py
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
|
||||
model = UNet2DConditionModel.from_single_file(ckpt_path, upcast_attention=True)
|
||||
```
|
||||
|
||||
### Local files
|
||||
|
||||
The [`~loaders.FromSingleFileMixin.from_single_file`] method attempts to configure a pipeline or model by inferring the model type from the keys in the checkpoint file. For example, any single file checkpoint based on the Stable Diffusion XL base model is configured from [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
|
||||
|
||||
If you're working with local files, download the config files with the [`~huggingface_hub.snapshot_download`] method and the model checkpoint with [`~huggingface_hub.hf_hub_download`]. These files are downloaded to your [cache directory](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache), but you can download them to a specific directory with the `local_dir` argument.
|
||||
|
||||
```py
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
my_local_checkpoint_path = hf_hub_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
filename="SSD-1B.safetensors"
|
||||
)
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
my_local_checkpoint_path, config=my_local_config_path, local_files_only=True
|
||||
)
|
||||
```
|
||||
|
||||
### Symlink
|
||||
|
||||
If you're working with a file system that does not support symlinking, download the checkpoint file to a local directory first with the `local_dir` parameter. Using the `local_dir` parameter automatically disables symlinks.
|
||||
|
||||
```py
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
my_local_checkpoint_path = hf_hub_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
filename="SSD-1B.safetensors"
|
||||
local_dir="my_local_checkpoints",
|
||||
)
|
||||
print("My local checkpoint: ", my_local_checkpoint_path)
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
)
|
||||
print("My local config: ", my_local_config_path)
|
||||
```
|
||||
|
||||
Pass these paths to [`~loaders.FromSingleFileMixin.from_single_file`].
|
||||
|
||||
```py
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
my_local_checkpoint_path, config=my_local_config_path, local_files_only=True
|
||||
)
|
||||
```
|
||||
|
||||
## File types
|
||||
|
||||
Models can be stored in several file types. Safetensors is the most common file type but you may encounter other file types on the Hub or diffusion community.
|
||||
|
||||
### safetensors
|
||||
|
||||
> [!TIP]
|
||||
> Learn more about the design decisions and why safetensor files are preferred for saving and loading model weights in the [Safetensors audited as really safe and becoming the default](https://blog.eleuther.ai/safetensors-security-audit/) blog post.
|
||||
[Safetensors](https://hf.co/docs/safetensors) is a safe and fast file type for securely storing and loading tensors. It restricts the header size to limit certain types of attacks, supports lazy loading (useful for distributed setups), and generally loads faster.
|
||||
|
||||
[Safetensors](https://hf.co/docs/safetensors) is a safe and fast file format for securely storing and loading tensors. Safetensors restricts the header size to limit certain types of attacks, supports lazy loading (useful for distributed setups), and has generally faster loading speeds.
|
||||
Diffusers loads safetensors file by default (a required dependency) if they are available and the Safetensors library is installed.
|
||||
|
||||
Make sure you have the [Safetensors](https://hf.co/docs/safetensors) library installed.
|
||||
|
||||
```py
|
||||
!pip install safetensors
|
||||
```
|
||||
|
||||
Safetensors stores weights in a safetensors file. Diffusers loads safetensors files by default if they're available and the Safetensors library is installed. There are two ways safetensors files can be organized:
|
||||
|
||||
1. Diffusers-multifolder layout: there may be several separate safetensors files, one for each pipeline component (text encoder, UNet, VAE), organized in subfolders (check out the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main) repository as an example)
|
||||
2. single-file layout: all the model weights may be saved in a single file (check out the [WarriorMama777/OrangeMixs](https://hf.co/WarriorMama777/OrangeMixs/tree/main/Models/AbyssOrangeMix) repository as an example)
|
||||
|
||||
<hfoptions id="safetensors">
|
||||
<hfoption id="multifolder">
|
||||
|
||||
Use the [`~DiffusionPipeline.from_pretrained`] method to load a model with safetensors files stored in multiple folders.
|
||||
Use [`~DiffusionPipeline.from_pretrained`] or [`~loaders.FromSingleFileMixin.from_single_file`] to load safetensor files.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
use_safetensors=True
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch.dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
pipeline = DiffusionPipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="single file">
|
||||
If you're using a checkpoint trained with a Diffusers training script, metadata such as the LoRA configuration, is automatically saved. When the file is loaded, the metadata is parsed to correctly configure the LoRA and avoid missing or incorrect LoRA configs. Inspect the metadata of a safetensors file by clicking on the <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/safetensors/logo.png" alt="safetensors logo" style="vertical-align: middle; display: inline-block; max-height: 0.8em; max-width: 0.8em; margin: 0; padding: 0; line-height: 1;"> logo next to the file on the Hub.
|
||||
|
||||
Use the [`~loaders.FromSingleFileMixin.from_single_file`] method to load a model with all the weights stored in a single safetensors file.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_single_file(
|
||||
"https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
#### LoRAs
|
||||
|
||||
[LoRAs](../tutorials/using_peft_for_inference) are lightweight checkpoints fine-tuned to generate images or video in a specific style. If you are using a checkpoint trained with a Diffusers training script, the LoRA configuration is automatically saved as metadata in a safetensors file. When the safetensors file is loaded, the metadata is parsed to correctly configure the LoRA and avoids missing or incorrect LoRA configurations.
|
||||
|
||||
The easiest way to inspect the metadata, if available, is by clicking on the Safetensors logo next to the weights.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/safetensors_lora.png"/>
|
||||
</div>
|
||||
|
||||
For LoRAs that aren't trained with Diffusers, you can still save metadata with the `transformer_lora_adapter_metadata` and `text_encoder_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`] as long as it is a safetensors file.
|
||||
Save the metadata for LoRAs that aren't trained with Diffusers with either `transformer_lora_adapter_metadata` or `unet_lora_adapter_metadata` depending on your model. For the text encoder, use the `text_encoder_lora_adapter_metadata` and `text_encoder_2_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`]. This is only supported for safetensors files.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -91,423 +197,88 @@ pipeline = FluxPipeline.from_pretrained(
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA")
|
||||
pipeline.save_lora_weights(
|
||||
transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},
|
||||
text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8}
|
||||
text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8},
|
||||
text_encoder_2_lora_adapter_metadata={"r": 8, "lora_alpha": 8}
|
||||
)
|
||||
```
|
||||
|
||||
### ckpt
|
||||
|
||||
> [!WARNING]
|
||||
> Pickled files may be unsafe because they can be exploited to execute malicious code. It is recommended to use safetensors files instead where possible, or convert the weights to safetensors files.
|
||||
Older model weights are commonly saved with Python's [pickle](https://docs.python.org/3/library/pickle.html) utility in a ckpt file.
|
||||
|
||||
PyTorch's [torch.save](https://pytorch.org/docs/stable/generated/torch.save.html) function uses Python's [pickle](https://docs.python.org/3/library/pickle.html) utility to serialize and save models. These files are saved as a ckpt file and they contain the entire model's weights.
|
||||
Pickled files may be unsafe because they can be exploited to execute malicious code. It is recommended to use safetensors files or convert the weights to safetensors files.
|
||||
|
||||
Use the [`~loaders.FromSingleFileMixin.from_single_file`] method to directly load a ckpt file.
|
||||
Use [`~loaders.FromSingleFileMixin.from_single_file`] to load a ckpt file.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_single_file(
|
||||
pipeline = DiffusionPipeline.from_single_file(
|
||||
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt"
|
||||
)
|
||||
```
|
||||
|
||||
## Storage layout
|
||||
### dduf
|
||||
|
||||
There are two ways model files are organized, either in a Diffusers-multifolder layout or in a single-file layout. The Diffusers-multifolder layout is the default, and each component file (text encoder, UNet, VAE) is stored in a separate subfolder. Diffusers also supports loading models from a single-file layout where all the components are bundled together.
|
||||
> [!TIP]
|
||||
> DDUF is an experimental file type and the API may change. Refer to the DDUF [docs](https://huggingface.co/docs/hub/dduf) to learn more.
|
||||
|
||||
### Diffusers-multifolder
|
||||
DDUF is a file type designed to unify different diffusion model distribution methods and weight-saving formats. It is a standardized and flexible method to package all components of a diffusion model into a single file, providing a balance between the Diffusers and single-file formats.
|
||||
|
||||
The Diffusers-multifolder layout is the default storage layout for Diffusers. Each component's (text encoder, UNet, VAE) weights are stored in a separate subfolder. The weights can be stored as safetensors or ckpt files.
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/multifolder-layout.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">multifolder layout</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/multifolder-unet.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">UNet subfolder</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
To load from Diffusers-multifolder layout, use the [`~DiffusionPipeline.from_pretrained`] method.
|
||||
Use the `dduf_file` argument in [`~DiffusionPipeline.from_pretrained`] to load a DDUF file. You can also load quantized dduf files as long as they are stored in the Diffusers format.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
"DDUF/FLUX.1-dev-DDUF",
|
||||
dduf_file="FLUX.1-dev.dduf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
Benefits of using the Diffusers-multifolder layout include:
|
||||
|
||||
1. Faster to load each component file individually or in parallel.
|
||||
2. Reduced memory usage because you only load the components you need. For example, models like [SDXL Turbo](https://hf.co/stabilityai/sdxl-turbo), [SDXL Lightning](https://hf.co/ByteDance/SDXL-Lightning), and [Hyper-SD](https://hf.co/ByteDance/Hyper-SD) have the same components except for the UNet. You can reuse their shared components with the [`~DiffusionPipeline.from_pipe`] method without consuming any additional memory (take a look at the [Reuse a pipeline](./loading#reuse-a-pipeline) guide) and only load the UNet. This way, you don't need to download redundant components and unnecessarily use more memory.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
|
||||
|
||||
# download one model
|
||||
sdxl_pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
|
||||
# switch UNet for another model
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stabilityai/sdxl-turbo",
|
||||
subfolder="unet",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True
|
||||
)
|
||||
# reuse all the same components in new model except for the UNet
|
||||
turbo_pipeline = StableDiffusionXLPipeline.from_pipe(
|
||||
sdxl_pipeline, unet=unet,
|
||||
).to("cuda")
|
||||
turbo_pipeline.scheduler = EulerDiscreteScheduler.from_config(
|
||||
turbo_pipeline.scheduler.config,
|
||||
timestep_spacing="trailing"
|
||||
)
|
||||
image = turbo_pipeline(
|
||||
"an astronaut riding a unicorn on mars",
|
||||
num_inference_steps=1,
|
||||
guidance_scale=0.0,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
3. Reduced storage requirements because if a component, such as the SDXL [VAE](https://hf.co/madebyollin/sdxl-vae-fp16-fix), is shared across multiple models, you only need to download and store a single copy of it instead of downloading and storing it multiple times. For 10 SDXL models, this can save ~3.5GB of storage. The storage savings is even greater for newer models like PixArt Sigma, where the [text encoder](https://hf.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS/tree/main/text_encoder) alone is ~19GB!
|
||||
4. Flexibility to replace a component in the model with a newer or better version.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, AutoencoderKL
|
||||
|
||||
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
vae=vae,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
5. More visibility and information about a model's components, which are stored in a [config.json](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json) file in each component subfolder.
|
||||
|
||||
### Single-file
|
||||
|
||||
The single-file layout stores all the model weights in a single file. All the model components (text encoder, UNet, VAE) weights are kept together instead of separately in subfolders. This can be a safetensors or ckpt file.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/single-file-layout.png"/>
|
||||
</div>
|
||||
|
||||
To load from a single-file layout, use the [`~loaders.FromSingleFileMixin.from_single_file`] method.
|
||||
To save a pipeline as a dduf file, use the [`~huggingface_hub.export_folder_as_dduf`] utility.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Benefits of using a single-file layout include:
|
||||
|
||||
1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
|
||||
2. Easier to manage (download and share) a single file.
|
||||
|
||||
### DDUF
|
||||
|
||||
> [!WARNING]
|
||||
> DDUF is an experimental file format and APIs related to it can change in the future.
|
||||
|
||||
DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers multi-folder format and the widely popular single-file format.
|
||||
|
||||
Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf).
|
||||
|
||||
Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`].
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
image = pipe(
|
||||
"photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations.
|
||||
|
||||
```py
|
||||
from huggingface_hub import export_folder_as_dduf
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
||||
pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
||||
|
||||
save_folder = "flux-dev"
|
||||
pipe.save_pretrained("flux-dev")
|
||||
pipeline.save_pretrained("flux-dev")
|
||||
export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure.
|
||||
## Converting formats and files
|
||||
|
||||
## Convert layout and files
|
||||
Diffusers provides scripts and methods to convert format and files to enable broader support across the diffusion ecosystem.
|
||||
|
||||
Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.
|
||||
Take a look at the [diffusers/scripts](https://github.com/huggingface/diffusers/tree/main/scripts) folder to find a conversion script. Scripts with `"to_diffusers` appended at the end converts a model to the Diffusers format. Each script has a specific set of arguments for configuring the conversion. Make sure you check what arguments are available.
|
||||
|
||||
Take a look at the [diffusers/scripts](https://github.com/huggingface/diffusers/tree/main/scripts) collection to find a script that fits your conversion needs.
|
||||
|
||||
> [!TIP]
|
||||
> Scripts that have "`to_diffusers`" appended at the end mean they convert a model to the Diffusers-multifolder layout. Each script has their own specific set of arguments for configuring the conversion, so make sure you check what arguments are available!
|
||||
|
||||
For example, to convert a Stable Diffusion XL model stored in Diffusers-multifolder layout to a single-file layout, run the [convert_diffusers_to_original_sdxl.py](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_sdxl.py) script. Provide the path to the model to convert, and the path to save the converted model to. You can optionally specify whether you want to save the model as a safetensors file and whether to save the model in half-precision.
|
||||
The example below converts a model stored in Diffusers format to a single-file format. Provide the path to the model to convert and where to save the converted model. You can optionally specify what file type and data type to save the model as.
|
||||
|
||||
```bash
|
||||
python convert_diffusers_to_original_sdxl.py --model_path path/to/model/to/convert --checkpoint_path path/to/save/model/to --use_safetensors
|
||||
```
|
||||
|
||||
You can also save a model to Diffusers-multifolder layout with the [`~DiffusionPipeline.save_pretrained`] method. This creates a directory for you if it doesn't already exist, and it also saves the files as a safetensors file by default.
|
||||
The [`~DiffusionPipeline.save_pretrained`] method also saves a model in Diffusers format and takes care of creating subfolders for each model. It saves the files as safetensor files by default.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
pipeline = DiffusionPipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
|
||||
)
|
||||
pipeline.save_pretrained()
|
||||
```
|
||||
|
||||
Lastly, there are also Spaces, such as [SD To Diffusers](https://hf.co/spaces/diffusers/sd-to-diffusers) and [SD-XL To Diffusers](https://hf.co/spaces/diffusers/sdxl-to-diffusers), that provide a more user-friendly interface for converting models to Diffusers-multifolder layout. This is the easiest and most convenient option for converting layouts, and it'll open a PR on your model repository with the converted files. However, this option is not as reliable as running a script, and the Space may fail for more complicated models.
|
||||
Finally, you can use a Space like [SD To Diffusers](https://hf.co/spaces/diffusers/sd-to-diffusers) or [SD-XL To Diffusers](https://hf.co/spaces/diffusers/sdxl-to-diffusers) to convert models to the Diffusers format. It'll open a PR on your model repository with the converted files. This is the easiest way to convert a model, but it may fail for more complicated models. Using a conversion script is more reliable.
|
||||
|
||||
## Single-file layout usage
|
||||
## Resources
|
||||
|
||||
Now that you're familiar with the differences between the Diffusers-multifolder and single-file layout, this section shows you how to load models and pipeline components, customize configuration options for loading, and load local files with the [`~loaders.FromSingleFileMixin.from_single_file`] method.
|
||||
- Learn more about the design decisions and why safetensor files are preferred for saving and loading model weights in the [Safetensors audited as really safe and becoming the default](https://blog.eleuther.ai/safetensors-security-audit/) blog post.
|
||||
|
||||
### Load a pipeline or model
|
||||
|
||||
Pass the file path of the pipeline or model to the [`~loaders.FromSingleFileMixin.from_single_file`] method to load it.
|
||||
|
||||
<hfoptions id="pipeline-model">
|
||||
<hfoption id="pipeline">
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="model">
|
||||
|
||||
```py
|
||||
from diffusers import StableCascadeUNet
|
||||
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
|
||||
model = StableCascadeUNet.from_single_file(ckpt_path)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Customize components in the pipeline by passing them directly to the [`~loaders.FromSingleFileMixin.from_single_file`] method. For example, you can use a different scheduler in a pipeline.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
|
||||
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
|
||||
scheduler = DDIMScheduler()
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, scheduler=scheduler)
|
||||
```
|
||||
|
||||
Or you could use a ControlNet model in the pipeline.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
ckpt_path = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipeline = StableDiffusionControlNetPipeline.from_single_file(ckpt_path, controlnet=controlnet)
|
||||
```
|
||||
|
||||
### Customize configuration options
|
||||
|
||||
Models have a configuration file that define their attributes like the number of inputs in a UNet. Pipelines configuration options are available in the pipeline's class. For example, if you look at the [`StableDiffusionXLInstructPix2PixPipeline`] class, there is an option to scale the image latents with the `is_cosxl_edit` parameter.
|
||||
|
||||
These configuration files can be found in the models Hub repository or another location from which the configuration file originated (for example, a GitHub repository or locally on your device).
|
||||
|
||||
<hfoptions id="config-file">
|
||||
<hfoption id="Hub configuration file">
|
||||
|
||||
> [!TIP]
|
||||
> The [`~loaders.FromSingleFileMixin.from_single_file`] method automatically maps the checkpoint to the appropriate model repository, but there are cases where it is useful to use the `config` parameter. For example, if the model components in the checkpoint are different from the original checkpoint or if a checkpoint doesn't have the necessary metadata to correctly determine the configuration to use for the pipeline.
|
||||
|
||||
The [`~loaders.FromSingleFileMixin.from_single_file`] method automatically determines the configuration to use from the configuration file in the model repository. You could also explicitly specify the configuration to use by providing the repository id to the `config` parameter.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
ckpt_path = "https://huggingface.co/segmind/SSD-1B/blob/main/SSD-1B.safetensors"
|
||||
repo_id = "segmind/SSD-1B"
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, config=repo_id)
|
||||
```
|
||||
|
||||
The model loads the configuration file for the [UNet](https://huggingface.co/segmind/SSD-1B/blob/main/unet/config.json), [VAE](https://huggingface.co/segmind/SSD-1B/blob/main/vae/config.json), and [text encoder](https://huggingface.co/segmind/SSD-1B/blob/main/text_encoder/config.json) from their respective subfolders in the repository.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="original configuration file">
|
||||
|
||||
The [`~loaders.FromSingleFileMixin.from_single_file`] method can also load the original configuration file of a pipeline that is stored elsewhere. Pass a local path or URL of the original configuration file to the `original_config` parameter.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
|
||||
original_config = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, original_config=original_config)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Diffusers attempts to infer the pipeline components based on the type signatures of the pipeline class when you use `original_config` with `local_files_only=True`, instead of fetching the configuration files from the model repository on the Hub. This prevents backward breaking changes in code that can't connect to the internet to fetch the necessary configuration files.
|
||||
>
|
||||
> This is not as reliable as providing a path to a local model repository with the `config` parameter, and might lead to errors during pipeline configuration. To avoid errors, run the pipeline with `local_files_only=False` once to download the appropriate pipeline configuration files to the local cache.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
While the configuration files specify the pipeline or models default parameters, you can override them by providing the parameters directly to the [`~loaders.FromSingleFileMixin.from_single_file`] method. Any parameter supported by the model or pipeline class can be configured in this way.
|
||||
|
||||
<hfoptions id="override">
|
||||
<hfoption id="pipeline">
|
||||
|
||||
For example, to scale the image latents in [`StableDiffusionXLInstructPix2PixPipeline`] pass the `is_cosxl_edit` parameter.
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLInstructPix2PixPipeline
|
||||
|
||||
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
|
||||
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_single_file(ckpt_path, config="diffusers/sdxl-instructpix2pix-768", is_cosxl_edit=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="model">
|
||||
|
||||
For example, to upcast the attention dimensions in a [`UNet2DConditionModel`] pass the `upcast_attention` parameter.
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
|
||||
model = UNet2DConditionModel.from_single_file(ckpt_path, upcast_attention=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Local files
|
||||
|
||||
In Diffusers>=v0.28.0, the [`~loaders.FromSingleFileMixin.from_single_file`] method attempts to configure a pipeline or model by inferring the model type from the keys in the checkpoint file. The inferred model type is used to determine the appropriate model repository on the Hugging Face Hub to configure the model or pipeline.
|
||||
|
||||
For example, any single file checkpoint based on the Stable Diffusion XL base model will use the [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model repository to configure the pipeline.
|
||||
|
||||
But if you're working in an environment with restricted internet access, you should download the configuration files with the [`~huggingface_hub.snapshot_download`] function, and the model checkpoint with the [`~huggingface_hub.hf_hub_download`] function. By default, these files are downloaded to the Hugging Face Hub [cache directory](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache), but you can specify a preferred directory to download the files to with the `local_dir` parameter.
|
||||
|
||||
Pass the configuration and checkpoint paths to the [`~loaders.FromSingleFileMixin.from_single_file`] method to load locally.
|
||||
|
||||
<hfoptions id="local">
|
||||
<hfoption id="Hub cache directory">
|
||||
|
||||
```python
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
my_local_checkpoint_path = hf_hub_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
filename="SSD-1B.safetensors"
|
||||
)
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="specific local directory">
|
||||
|
||||
```python
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
my_local_checkpoint_path = hf_hub_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
filename="SSD-1B.safetensors"
|
||||
local_dir="my_local_checkpoints"
|
||||
)
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
local_dir="my_local_config"
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
#### Local files without symlink
|
||||
|
||||
> [!TIP]
|
||||
> In huggingface_hub>=v0.23.0, the `local_dir_use_symlinks` argument isn't necessary for the [`~huggingface_hub.hf_hub_download`] and [`~huggingface_hub.snapshot_download`] functions.
|
||||
|
||||
The [`~loaders.FromSingleFileMixin.from_single_file`] method relies on the [huggingface_hub](https://hf.co/docs/huggingface_hub/index) caching mechanism to fetch and store checkpoints and configuration files for models and pipelines. If you're working with a file system that does not support symlinking, you should download the checkpoint file to a local directory first, and disable symlinking with the `local_dir_use_symlink=False` parameter in the [`~huggingface_hub.hf_hub_download`] function and [`~huggingface_hub.snapshot_download`] functions.
|
||||
|
||||
```python
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
my_local_checkpoint_path = hf_hub_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
filename="SSD-1B.safetensors"
|
||||
local_dir="my_local_checkpoints",
|
||||
local_dir_use_symlinks=False
|
||||
)
|
||||
print("My local checkpoint: ", my_local_checkpoint_path)
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
print("My local config: ", my_local_config_path)
|
||||
```
|
||||
|
||||
Then you can pass the local paths to the `pretrained_model_link_or_path` and `config` parameters.
|
||||
|
||||
```python
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
|
||||
```
|
||||
|
||||
@@ -1,235 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Scheduler features
|
||||
|
||||
The scheduler is an important component of any diffusion model because it controls the entire denoising (or sampling) process. There are many types of schedulers, some are optimized for speed and some for quality. With Diffusers, you can modify the scheduler configuration to use custom noise schedules, sigmas, and rescale the noise schedule. Changing these parameters can have profound effects on inference quality and speed.
|
||||
|
||||
This guide will demonstrate how to use these features to improve inference quality.
|
||||
|
||||
> [!TIP]
|
||||
> Diffusers currently only supports the `timesteps` and `sigmas` parameters for a select list of schedulers and pipelines. Feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
|
||||
|
||||
## Timestep schedules
|
||||
|
||||
The timestep or noise schedule determines the amount of noise at each sampling step. The scheduler uses this to generate an image with the corresponding amount of noise at each step. The timestep schedule is generated from the scheduler's default configuration, but you can customize the scheduler to use new and optimized sampling schedules that aren't in Diffusers yet.
|
||||
|
||||
For example, [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) is a method for optimizing a sampling schedule to generate a high-quality image in as little as 10 steps. The optimal [10-step schedule](https://github.com/huggingface/diffusers/blob/a7bf77fc284810483f1e60afe34d1d27ad91ce2e/src/diffusers/schedulers/scheduling_utils.py#L51) for Stable Diffusion XL is:
|
||||
|
||||
```py
|
||||
from diffusers.schedulers import AysSchedules
|
||||
|
||||
sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
|
||||
print(sampling_schedule)
|
||||
"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
|
||||
```
|
||||
|
||||
You can use the AYS sampling schedule in a pipeline by passing it to the `timesteps` parameter.
|
||||
|
||||
```py
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++")
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
generator = torch.Generator(device="cpu").manual_seed(2487854446)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
generator=generator,
|
||||
timesteps=sampling_schedule,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Timestep spacing
|
||||
|
||||
The way sample steps are selected in the schedule can affect the quality of the generated image, especially with respect to [rescaling the noise schedule](#rescale-noise-schedule), which can enable a model to generate much brighter or darker images. Diffusers provides three timestep spacing methods:
|
||||
|
||||
- `leading` creates evenly spaced steps
|
||||
- `linspace` includes the first and last steps and evenly selects the remaining intermediate steps
|
||||
- `trailing` only includes the last step and evenly selects the remaining intermediate steps starting from the end
|
||||
|
||||
It is recommended to use the `trailing` spacing method because it generates higher quality images with more details when there are fewer sample steps. But the difference in quality is not as obvious for more standard sample step values.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
|
||||
|
||||
prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
|
||||
generator = torch.Generator(device="cpu").manual_seed(2487854446)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
generator=generator,
|
||||
num_inference_steps=5,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/trailing_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">trailing spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/leading_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">leading spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Sigmas
|
||||
|
||||
The `sigmas` parameter is the amount of noise added at each timestep according to the timestep schedule. Like the `timesteps` parameter, you can customize the `sigmas` parameter to control how much noise is added at each step. When you use a custom `sigmas` value, the `timesteps` are calculated from the custom `sigmas` value and the default scheduler configuration is ignored.
|
||||
|
||||
For example, you can manually pass the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) for something like the 10-step AYS schedule from before to the pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
|
||||
prompt = "anthropomorphic capybara wearing a suit and working with a computer"
|
||||
generator = torch.Generator(device='cuda').manual_seed(123)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=10,
|
||||
sigmas=sigmas,
|
||||
generator=generator
|
||||
).images[0]
|
||||
```
|
||||
|
||||
When you take a look at the scheduler's `timesteps` parameter, you'll see that it is the same as the AYS timestep schedule because the `timestep` schedule is calculated from the `sigmas`.
|
||||
|
||||
```py
|
||||
print(f" timesteps: {pipe.scheduler.timesteps}")
|
||||
"timesteps: tensor([999., 845., 730., 587., 443., 310., 193., 116., 53., 13.], device='cuda:0')"
|
||||
```
|
||||
|
||||
### Karras sigmas
|
||||
|
||||
> [!TIP]
|
||||
> Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas.
|
||||
>
|
||||
> Karras sigmas should not be used for models that weren't trained with them. For example, the base Stable Diffusion XL model shouldn't use Karras sigmas but the [DreamShaperXL](https://hf.co/Lykon/dreamshaper-xl-1-0) model can since they are trained with Karras sigmas.
|
||||
|
||||
Karras scheduler's use the timestep schedule and sigmas from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://hf.co/papers/2206.00364) paper. This scheduler variant applies a smaller amount of noise per step as it approaches the end of the sampling process compared to other schedulers, and can increase the level of details in the generated image.
|
||||
|
||||
Enable Karras sigmas by setting `use_karras_sigmas=True` in the scheduler.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
generator = torch.Generator(device="cpu").manual_seed(2487854446)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
generator=generator,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_true.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas enabled</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_false.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas disabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Rescale noise schedule
|
||||
|
||||
In the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://hf.co/papers/2305.08891) paper, the authors discovered that common noise schedules allowed some signal to leak into the last timestep. This signal leakage at inference can cause models to only generate images with medium brightness. By enforcing a zero signal-to-noise ratio (SNR) for the timstep schedule and sampling from the last timestep, the model can be improved to generate very bright or dark images.
|
||||
|
||||
> [!TIP]
|
||||
> For inference, you need a model that has been trained with *v_prediction*. To train your own model with *v_prediction*, add the following flag to the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) scripts.
|
||||
>
|
||||
> ```bash
|
||||
> --prediction_type="v_prediction"
|
||||
> ```
|
||||
|
||||
For example, load the [ptx0/pseudo-journey-v2](https://hf.co/ptx0/pseudo-journey-v2) checkpoint which was trained with `v_prediction` and the [`DDIMScheduler`]. Configure the following parameters in the [`DDIMScheduler`]:
|
||||
|
||||
* `rescale_betas_zero_snr=True` to rescale the noise schedule to zero SNR
|
||||
* `timestep_spacing="trailing"` to start sampling from the last timestep
|
||||
|
||||
Set `guidance_rescale` in the pipeline to prevent over-exposure. A lower value increases brightness but some of the details may appear washed out.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, DDIMScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", use_safetensors=True)
|
||||
|
||||
pipeline.scheduler = DDIMScheduler.from_config(
|
||||
pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
prompt = "cinematic photo of a snowy mountain at night with the northern lights aurora borealis overhead, 35mm photograph, film, professional, 4k, highly detailed"
|
||||
generator = torch.Generator(device="cpu").manual_seed(23)
|
||||
image = pipeline(prompt, guidance_rescale=0.7, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/no-zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">default Stable Diffusion v2-1 image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">image with zero SNR and trailing timestep spacing enabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
@@ -10,200 +10,273 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Load schedulers and models
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
Diffusion pipelines are a collection of interchangeable schedulers and models that can be mixed and matched to tailor a pipeline to a specific use case. The scheduler encapsulates the entire denoising process such as the number of denoising steps and the algorithm for finding the denoised sample. A scheduler is not parameterized or trained so they don't take very much memory. The model is usually only concerned with the forward pass of going from a noisy input to a less noisy sample.
|
||||
# Schedulers
|
||||
|
||||
This guide will show you how to load schedulers and models to customize a pipeline. You'll use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint throughout this guide, so let's load it first.
|
||||
A scheduler is an algorithm that provides instructions to the denoising process such as how much noise to remove at a certain step. It takes the model prediction from step *t* and applies an update for how to compute the next sample at step *t-1*. Different schedulers produce different results; some are faster while others are more accurate.
|
||||
|
||||
Diffusers supports many schedulers and allows you to modify their timestep schedules, timestep spacing, and more, to generate high-quality images in fewer steps.
|
||||
|
||||
This guide will show you how to load and customize schedulers.
|
||||
|
||||
## Loading schedulers
|
||||
|
||||
Schedulers don't have any parameters and are defined in a configuration file. Access the `.scheduler` attribute of a pipeline to view the configuration.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
You can see what scheduler this pipeline uses with the `pipeline.scheduler` attribute.
|
||||
|
||||
```py
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler
|
||||
PNDMScheduler {
|
||||
"_class_name": "PNDMScheduler",
|
||||
"_diffusers_version": "0.21.4",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": false,
|
||||
"num_train_timesteps": 1000,
|
||||
"set_alpha_to_one": false,
|
||||
"skip_prk_steps": true,
|
||||
"steps_offset": 1,
|
||||
"timestep_spacing": "leading",
|
||||
"trained_betas": null
|
||||
}
|
||||
```
|
||||
|
||||
## Load a scheduler
|
||||
|
||||
Schedulers are defined by a configuration file that can be used by a variety of schedulers. Load a scheduler with the [`SchedulerMixin.from_pretrained`] method, and specify the `subfolder` parameter to load the configuration file into the correct subfolder of the pipeline repository.
|
||||
|
||||
For example, to load the [`DDIMScheduler`]:
|
||||
|
||||
```py
|
||||
from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
|
||||
ddim = DDIMScheduler.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler")
|
||||
```
|
||||
|
||||
Then you can pass the newly loaded scheduler to the pipeline.
|
||||
|
||||
```python
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=ddim, torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
## Compare schedulers
|
||||
|
||||
Schedulers have their own unique strengths and weaknesses, making it difficult to quantitatively compare which scheduler works best for a pipeline. You typically have to make a trade-off between denoising speed and denoising quality. We recommend trying out different schedulers to find one that works best for your use case. Call the `pipeline.scheduler.compatibles` attribute to see what schedulers are compatible with a pipeline.
|
||||
|
||||
Let's compare the [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], and the [`DPMSolverMultistepScheduler`] on the following prompt and seed.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
|
||||
generator = torch.Generator(device="cuda").manual_seed(8)
|
||||
```
|
||||
|
||||
To change the pipelines scheduler, use the [`~ConfigMixin.from_config`] method to load a different scheduler's `pipeline.scheduler.config` into the pipeline.
|
||||
|
||||
<hfoptions id="schedulers">
|
||||
<hfoption id="LMSDiscreteScheduler">
|
||||
|
||||
[`LMSDiscreteScheduler`] typically generates higher quality images than the default scheduler.
|
||||
|
||||
```py
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="EulerDiscreteScheduler">
|
||||
|
||||
[`EulerDiscreteScheduler`] can generate higher quality images in just 30 steps.
|
||||
|
||||
```py
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="EulerAncestralDiscreteScheduler">
|
||||
|
||||
[`EulerAncestralDiscreteScheduler`] can generate higher quality images in just 30 steps.
|
||||
|
||||
```py
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPMSolverMultistepScheduler">
|
||||
|
||||
[`DPMSolverMultistepScheduler`] provides a balance between speed and quality and can generate higher quality images in just 20 steps.
|
||||
Load a different scheduler with [`~SchedulerMixin.from_pretrained`] and specify the `subfolder` argument to load the configuration file into the correct subfolder of the pipeline repository. Pass the new scheduler to the existing pipeline.
|
||||
|
||||
```py
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
dpm = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
scheduler=dpm,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler
|
||||
```
|
||||
|
||||
## Timestep schedules
|
||||
|
||||
Timestep or noise schedule decides how noise is distributed over the denoising process. The schedule can be linear or more concentrated toward the beginning or end. It is a precomputed sequence of noise levels generated from the scheduler's default configuration, but it can be customized to use other schedules.
|
||||
|
||||
> [!TIP]
|
||||
> The `timesteps` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
|
||||
|
||||
The example below uses the [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) schedule which can generate a high-quality image in 10 steps, significantly speeding up generation and reducing computation time.
|
||||
|
||||
Import the schedule and pass it to the `timesteps` argument in the pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
from diffusers.schedulers import AysSchedules
|
||||
|
||||
sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
|
||||
print(sampling_schedule)
|
||||
"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
|
||||
)
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
timesteps=sampling_schedule,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Rescaling schedules
|
||||
|
||||
Denoising should begin with pure noise and the signal-to-noise (SNR) ration should be zero. However, some models don't actually start from pure noise which makes it difficult to generate images at brightness extremes.
|
||||
|
||||
> [!TIP]
|
||||
> Train your own model with `v_prediction` by adding the `--prediction_type="v_prediction"` flag to your training script. You can also [search](https://huggingface.co/search/full-text?q=v_prediction&type=model) for existing models trained with `v_prediction`.
|
||||
|
||||
To fix this, a model must be trained with `v_prediction`. If a model is trained with `v_prediction`, then enable the following arguments in the scheduler.
|
||||
|
||||
- Set `rescale_betas_zero_snr=True` to rescale the noise schedule to the very last timestep with exactly zero SNR
|
||||
- Set `timestep_spacing="trailing"` to force sampling from the last timestep with pure noise
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, DDIMScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", device_map="cuda")
|
||||
|
||||
pipeline.scheduler = DDIMScheduler.from_config(
|
||||
pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
||||
)
|
||||
```
|
||||
|
||||
Set `guidance_rescale` in the pipeline to avoid overexposed images. A lower value increases brightness, but some details may appear washed out.
|
||||
|
||||
```py
|
||||
prompt = """
|
||||
cinematic photo of a snowy mountain at night with the northern lights aurora borealis
|
||||
overhead, 35mm photograph, film, professional, 4k, highly detailed
|
||||
"""
|
||||
image = pipeline(prompt, guidance_rescale=0.7).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/no-zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">default Stable Diffusion v2-1 image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">image with zero SNR and trailing timestep spacing enabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Timestep spacing
|
||||
|
||||
Timestep spacing refers to the specific steps *t* to sample from from the schedule. Diffusers provides three spacing types as shown below.
|
||||
|
||||
| spacing strategy | spacing calculation | example timesteps |
|
||||
|---|---|---|
|
||||
| `leading` | evenly spaced steps | `[900, 800, 700, ..., 100, 0]` |
|
||||
| `linspace` | include first and last steps and evenly divide remaining intermediate steps | `[1000, 888.89, 777.78, ..., 111.11, 0]` |
|
||||
| `trailing` | include last step and evenly divide remaining intermediate steps beginning from the end | `[999, 899, 799, 699, 599, 499, 399, 299, 199, 99]` |
|
||||
|
||||
Pass the spacing strategy to the `timestep_spacing` argument in the scheduler.
|
||||
|
||||
> [!TIP]
|
||||
> The `trailing` strategy typically produces higher quality images with more details with fewer steps, but the difference in quality is not as obvious for more standard step values.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, timestep_spacing="trailing"
|
||||
)
|
||||
|
||||
prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
num_inference_steps=5,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">LMSDiscreteScheduler</figcaption>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/trailing_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">trailing spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">EulerDiscreteScheduler</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">EulerAncestralDiscreteScheduler</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">DPMSolverMultistepScheduler</figcaption>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/leading_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">leading spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Most images look very similar and are comparable in quality. Again, it often comes down to your specific use case so a good approach is to run multiple different schedulers and compare the results.
|
||||
## Sigmas
|
||||
|
||||
## Models
|
||||
Sigmas is a measure of how noisy a sample is at a certain step as defined by the schedule. When using custom `sigmas`, the `timesteps` are calculated from these values instead of the default scheduler configuration.
|
||||
|
||||
Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them.
|
||||
> [!TIP]
|
||||
> The `sigmas` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
|
||||
|
||||
Models can be loaded from a subfolder with the `subfolder` argument. For example, the model weights for [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) are stored in the [unet](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet) subfolder.
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", use_safetensors=True)
|
||||
```
|
||||
|
||||
They can also be directly loaded from a [repository](https://huggingface.co/google/ddpm-cifar10-32/tree/main).
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DModel
|
||||
|
||||
unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
|
||||
```
|
||||
|
||||
To load and save model variants, specify the `variant` argument in [`ModelMixin.from_pretrained`] and [`ModelMixin.save_pretrained`].
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
|
||||
)
|
||||
unet.save_pretrained("./local-unet", variant="non_ema")
|
||||
```
|
||||
|
||||
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
|
||||
Pass the custom sigmas to the `sigmas` argument in the pipeline. The example below uses the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) from the 10-step AYS schedule.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
unet = AutoModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
|
||||
)
|
||||
|
||||
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
sigmas=sigmas,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights unlike the `torch_dtype` argument that respects the `_keep_in_fp32_modules`. This is important for models whose layers must remain in fp32 for numerical stability and best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
|
||||
### Karras sigmas
|
||||
|
||||
[Karras sigmas](https://huggingface.co/papers/2206.00364) resamples the noise schedule for more efficient sampling by clustering sigmas more densely in the middle of the sequence where structure reconstruction is critical, while using fewer sigmas at the beginning and end where noise changes have less impact. This can increase the level of details in a generated image.
|
||||
|
||||
Set `use_karras_sigmas=True` in the scheduler to enable it.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config,
|
||||
algorithm_type="sde-dpmsolver++",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
sigmas=sigmas,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_true.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas enabled</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_false.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas disabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas. It should only be used for models trained with Karras sigmas.
|
||||
|
||||
## Choosing a scheduler
|
||||
|
||||
It's important to try different schedulers to find the best one for your use case. Here are a few recommendations to help you get started.
|
||||
|
||||
- DPM++ 2M SDE Karras is generally a good all-purpose option.
|
||||
- [`TCDScheduler`] works well for distilled models.
|
||||
- [`FlowMatchEulerDiscreteScheduler`] and [`FlowMatchHeunDiscreteScheduler`] for FlowMatch models.
|
||||
- [`EulerDiscreteScheduler`] or [`EulerAncestralDiscreteScheduler`] for generating anime style images.
|
||||
- DPM++ 2M paired with [`LCMScheduler`] on SDXL for generating realistic images.
|
||||
|
||||
## Resources
|
||||
|
||||
- Read the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more details about rescaling the noise schedule to enforce zero SNR.
|
||||
@@ -0,0 +1,91 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToImageInput(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
size: str | None = None
|
||||
n: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PresetModels:
|
||||
SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
|
||||
SD3_5: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
"stabilityai/stable-diffusion-3.5-large-turbo",
|
||||
"stabilityai/stable-diffusion-3.5-medium",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TextToImagePipelineSD3:
|
||||
def __init__(self, model_path: str | None = None):
|
||||
self.model_path = model_path or os.getenv("MODEL_PATH")
|
||||
self.pipeline: StableDiffusion3Pipeline | None = None
|
||||
self.device: str | None = None
|
||||
|
||||
def start(self):
|
||||
if torch.cuda.is_available():
|
||||
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
|
||||
logger.info("Loading CUDA")
|
||||
self.device = "cuda"
|
||||
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.float16,
|
||||
).to(device=self.device)
|
||||
elif torch.backends.mps.is_available():
|
||||
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
|
||||
logger.info("Loading MPS for Mac M Series")
|
||||
self.device = "mps"
|
||||
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to(device=self.device)
|
||||
else:
|
||||
raise Exception("No CUDA or MPS device available")
|
||||
|
||||
|
||||
class ModelPipelineInitializer:
|
||||
def __init__(self, model: str = "", type_models: str = "t2im"):
|
||||
self.model = model
|
||||
self.type_models = type_models
|
||||
self.pipeline = None
|
||||
self.device = "cuda" if torch.cuda.is_available() else "mps"
|
||||
self.model_type = None
|
||||
|
||||
def initialize_pipeline(self):
|
||||
if not self.model:
|
||||
raise ValueError("Model name not provided")
|
||||
|
||||
# Check if model exists in PresetModels
|
||||
preset_models = PresetModels()
|
||||
|
||||
# Determine which model type we're dealing with
|
||||
if self.model in preset_models.SD3:
|
||||
self.model_type = "SD3"
|
||||
elif self.model in preset_models.SD3_5:
|
||||
self.model_type = "SD3_5"
|
||||
|
||||
# Create appropriate pipeline based on model type and type_models
|
||||
if self.type_models == "t2im":
|
||||
if self.model_type in ["SD3", "SD3_5"]:
|
||||
self.pipeline = TextToImagePipelineSD3(self.model)
|
||||
else:
|
||||
raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
|
||||
elif self.type_models == "t2v":
|
||||
raise ValueError(f"Unsupported type_models: {self.type_models}")
|
||||
|
||||
return self.pipeline
|
||||
@@ -0,0 +1,171 @@
|
||||
# Asynchronous server and parallel execution of models
|
||||
|
||||
> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.
|
||||
> We recommend running 10 to 50 inferences in parallel for optimal performance, averaging between 25 and 30 seconds to 1 minute and 1 minute and 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two inferences in parallel to avoid decoding or saving errors due to memory shortages.)
|
||||
|
||||
## ⚠️ IMPORTANT
|
||||
|
||||
* The example demonstrates how to run pipelines like `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on GPU.
|
||||
|
||||
## Necessary components
|
||||
|
||||
All the components needed to create the inference server are in the current directory:
|
||||
|
||||
```
|
||||
server-async/
|
||||
├── utils/
|
||||
├─────── __init__.py
|
||||
├─────── scheduler.py # BaseAsyncScheduler wrapper and async_retrieve_timesteps for secure inferences
|
||||
├─────── requestscopedpipeline.py # RequestScoped Pipeline for inference with a single in-memory model
|
||||
├─────── utils.py # Image/video saving utilities and service configuration
|
||||
├── Pipelines.py # pipeline loader classes (SD3)
|
||||
├── serverasync.py # FastAPI app with lifespan management and async inference endpoints
|
||||
├── test.py # Client test script for inference requests
|
||||
├── requirements.txt # Dependencies
|
||||
└── README.md # This documentation
|
||||
```
|
||||
|
||||
## What `diffusers-async` adds / Why we needed it
|
||||
|
||||
Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline per-request.
|
||||
|
||||
`diffusers-async` / this example addresses that by:
|
||||
|
||||
* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.
|
||||
* **Per-request mutable state**: stateful small objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with fallback to safe `deepcopy` or other heuristics.
|
||||
* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.
|
||||
* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully retro-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not using `return_scheduler=True`, the behavior is identical to the original API.
|
||||
* **Robust attribute handling**: wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. Configurable tensor size threshold prevents cloning of large tensors.
|
||||
* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).
|
||||
|
||||
## How the server works (high-level flow)
|
||||
|
||||
1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.
|
||||
2. On each HTTP inference request:
|
||||
|
||||
* The server uses `RequestScopedPipeline.generate(...)` which:
|
||||
|
||||
* automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped),
|
||||
* obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),
|
||||
* does `local_pipe = copy.copy(base_pipe)` (shallow copy),
|
||||
* sets `local_pipe.scheduler = local_scheduler` (if possible),
|
||||
* clones only small mutable attributes (callbacks, rng, small latents) with auto-detection,
|
||||
* wraps tokenizers with thread-safe locks to prevent race conditions,
|
||||
* optionally enters a `model_cpu_offload_context()` for memory offload hooks,
|
||||
* calls the pipeline on the local view (`local_pipe(...)`).
|
||||
3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`).
|
||||
4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state.
|
||||
|
||||
## How to set up and run the server
|
||||
|
||||
### 1) Install dependencies
|
||||
|
||||
Recommended: create a virtualenv / conda environment.
|
||||
|
||||
```bash
|
||||
pip install diffusers
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2) Start the server
|
||||
|
||||
Using the `serverasync.py` file that already has everything you need:
|
||||
|
||||
```bash
|
||||
python serverasync.py
|
||||
```
|
||||
|
||||
The server will start on `http://localhost:8500` by default with the following features:
|
||||
- FastAPI application with async lifespan management
|
||||
- Automatic model loading and pipeline initialization
|
||||
- Request counting and active inference tracking
|
||||
- Memory cleanup after each inference
|
||||
- CORS middleware for cross-origin requests
|
||||
|
||||
### 3) Test the server
|
||||
|
||||
Use the included test script:
|
||||
|
||||
```bash
|
||||
python test.py
|
||||
```
|
||||
|
||||
Or send a manual request:
|
||||
|
||||
`POST /api/diffusers/inference` with JSON body:
|
||||
|
||||
```json
|
||||
{
|
||||
"prompt": "A futuristic cityscape, vibrant colors",
|
||||
"num_inference_steps": 30,
|
||||
"num_images_per_prompt": 1
|
||||
}
|
||||
```
|
||||
|
||||
Response example:
|
||||
|
||||
```json
|
||||
{
|
||||
"response": ["http://localhost:8500/images/img123.png"]
|
||||
}
|
||||
```
|
||||
|
||||
### 4) Server endpoints
|
||||
|
||||
- `GET /` - Welcome message
|
||||
- `POST /api/diffusers/inference` - Main inference endpoint
|
||||
- `GET /images/{filename}` - Serve generated images
|
||||
- `GET /api/status` - Server status and memory info
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### RequestScopedPipeline Parameters
|
||||
|
||||
```python
|
||||
RequestScopedPipeline(
|
||||
pipeline, # Base pipeline to wrap
|
||||
mutable_attrs=None, # Custom list of attributes to clone
|
||||
auto_detect_mutables=True, # Enable automatic detection of mutable attributes
|
||||
tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning
|
||||
tokenizer_lock=None, # Custom threading lock for tokenizers
|
||||
wrap_scheduler=True # Auto-wrap scheduler in BaseAsyncScheduler
|
||||
)
|
||||
```
|
||||
|
||||
### BaseAsyncScheduler Features
|
||||
|
||||
* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`
|
||||
* `clone_for_request()` method for safe per-request scheduler cloning
|
||||
* Enhanced debugging with `__repr__` and `__str__` methods
|
||||
* Full compatibility with existing scheduler APIs
|
||||
|
||||
### Server Configuration
|
||||
|
||||
The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ServerConfigModels:
|
||||
model: str = 'stabilityai/stable-diffusion-3.5-medium'
|
||||
type_models: str = 't2im'
|
||||
host: str = '0.0.0.0'
|
||||
port: int = 8500
|
||||
```
|
||||
|
||||
## Troubleshooting (quick)
|
||||
|
||||
* `Already borrowed` — previously a Rust tokenizer concurrency error.
|
||||
✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen.
|
||||
|
||||
* `can't set attribute 'components'` — pipeline exposes read-only `components`.
|
||||
✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically.
|
||||
|
||||
* Scheduler issues:
|
||||
* If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log and fallback — but prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.
|
||||
✅ Note: `async_retrieve_timesteps` is fully retro-compatible — if you don't pass `return_scheduler=True`, the behavior is unchanged.
|
||||
|
||||
* Memory issues with large tensors:
|
||||
✅ The system now has configurable `tensor_numel_threshold` to prevent cloning of large tensors while still cloning small mutable ones.
|
||||
|
||||
* Automatic tokenizer detection:
|
||||
✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers.
|
||||
@@ -0,0 +1,10 @@
|
||||
torch
|
||||
torchvision
|
||||
transformers
|
||||
sentencepiece
|
||||
fastapi
|
||||
uvicorn
|
||||
ftfy
|
||||
accelerate
|
||||
xformers
|
||||
protobuf
|
||||
@@ -0,0 +1,230 @@
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from Pipelines import ModelPipelineInitializer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from utils import RequestScopedPipeline, Utils
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerConfigModels:
|
||||
model: str = "stabilityai/stable-diffusion-3.5-medium"
|
||||
type_models: str = "t2im"
|
||||
constructor_pipeline: Optional[Type] = None
|
||||
custom_pipeline: Optional[Type] = None
|
||||
components: Optional[Dict[str, Any]] = None
|
||||
torch_dtype: Optional[torch.dtype] = None
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8500
|
||||
|
||||
|
||||
server_config = ServerConfigModels()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
app.state.logger = logging.getLogger("diffusers-server")
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
|
||||
|
||||
app.state.total_requests = 0
|
||||
app.state.active_inferences = 0
|
||||
app.state.metrics_lock = asyncio.Lock()
|
||||
app.state.metrics_task = None
|
||||
|
||||
app.state.utils_app = Utils(
|
||||
host=server_config.host,
|
||||
port=server_config.port,
|
||||
)
|
||||
|
||||
async def metrics_loop():
|
||||
try:
|
||||
while True:
|
||||
async with app.state.metrics_lock:
|
||||
total = app.state.total_requests
|
||||
active = app.state.active_inferences
|
||||
app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
app.state.logger.info("Metrics loop cancelled")
|
||||
raise
|
||||
|
||||
app.state.metrics_task = asyncio.create_task(metrics_loop())
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
task = app.state.metrics_task
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
|
||||
if callable(stop_fn):
|
||||
await run_in_threadpool(stop_fn)
|
||||
except Exception as e:
|
||||
app.state.logger.warning(f"Error during pipeline shutdown: {e}")
|
||||
|
||||
app.state.logger.info("Lifespan shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
logger = logging.getLogger("DiffusersServer.Pipelines")
|
||||
|
||||
|
||||
initializer = ModelPipelineInitializer(
|
||||
model=server_config.model,
|
||||
type_models=server_config.type_models,
|
||||
)
|
||||
model_pipeline = initializer.initialize_pipeline()
|
||||
model_pipeline.start()
|
||||
|
||||
request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
|
||||
pipeline_lock = threading.Lock()
|
||||
|
||||
logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})")
|
||||
|
||||
app.state.MODEL_INITIALIZER = initializer
|
||||
app.state.MODEL_PIPELINE = model_pipeline
|
||||
app.state.REQUEST_PIPE = request_pipe
|
||||
app.state.PIPELINE_LOCK = pipeline_lock
|
||||
|
||||
|
||||
class JSONBodyQueryAPI(BaseModel):
|
||||
model: str | None = None
|
||||
prompt: str
|
||||
negative_prompt: str | None = None
|
||||
num_inference_steps: int = 28
|
||||
num_images_per_prompt: int = 1
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def count_requests_middleware(request: Request, call_next):
|
||||
async with app.state.metrics_lock:
|
||||
app.state.total_requests += 1
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Welcome to the Diffusers Server"}
|
||||
|
||||
|
||||
@app.post("/api/diffusers/inference")
|
||||
async def api(json: JSONBodyQueryAPI):
|
||||
prompt = json.prompt
|
||||
negative_prompt = json.negative_prompt or ""
|
||||
num_steps = json.num_inference_steps
|
||||
num_images_per_prompt = json.num_images_per_prompt
|
||||
|
||||
wrapper = app.state.MODEL_PIPELINE
|
||||
initializer = app.state.MODEL_INITIALIZER
|
||||
|
||||
utils_app = app.state.utils_app
|
||||
|
||||
if not wrapper or not wrapper.pipeline:
|
||||
raise HTTPException(500, "Model not initialized correctly")
|
||||
if not prompt.strip():
|
||||
raise HTTPException(400, "No prompt provided")
|
||||
|
||||
def make_generator():
|
||||
g = torch.Generator(device=initializer.device)
|
||||
return g.manual_seed(random.randint(0, 10_000_000))
|
||||
|
||||
req_pipe = app.state.REQUEST_PIPE
|
||||
|
||||
def infer():
|
||||
gen = make_generator()
|
||||
return req_pipe.generate(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=gen,
|
||||
num_inference_steps=num_steps,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=initializer.device,
|
||||
output_type="pil",
|
||||
)
|
||||
|
||||
try:
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences += 1
|
||||
|
||||
output = await run_in_threadpool(infer)
|
||||
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences = max(0, app.state.active_inferences - 1)
|
||||
|
||||
urls = [utils_app.save_image(img) for img in output.images]
|
||||
return {"response": urls}
|
||||
|
||||
except Exception as e:
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences = max(0, app.state.active_inferences - 1)
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise HTTPException(500, f"Error in processing: {e}")
|
||||
|
||||
finally:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.ipc_collect()
|
||||
gc.collect()
|
||||
|
||||
|
||||
@app.get("/images/{filename}")
|
||||
async def serve_image(filename: str):
|
||||
utils_app = app.state.utils_app
|
||||
file_path = os.path.join(utils_app.image_dir, filename)
|
||||
if not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
return FileResponse(file_path, media_type="image/png")
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status():
|
||||
memory_info = {}
|
||||
if torch.cuda.is_available():
|
||||
memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
|
||||
memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB
|
||||
memory_info = {
|
||||
"memory_allocated_gb": round(memory_allocated, 2),
|
||||
"memory_reserved_gb": round(memory_reserved, 2),
|
||||
"device": torch.cuda.get_device_name(0),
|
||||
}
|
||||
|
||||
return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host=server_config.host, port=server_config.port)
|
||||
@@ -0,0 +1,65 @@
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
SERVER_URL = "http://localhost:8500/api/diffusers/inference"
|
||||
BASE_URL = "http://localhost:8500"
|
||||
DOWNLOAD_FOLDER = "generated_images"
|
||||
WAIT_BEFORE_DOWNLOAD = 2 # seconds
|
||||
|
||||
os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)
|
||||
|
||||
|
||||
def save_from_url(url: str) -> str:
|
||||
"""Download the given URL (relative or absolute) and save it locally."""
|
||||
if url.startswith("/"):
|
||||
direct = BASE_URL.rstrip("/") + url
|
||||
else:
|
||||
direct = url
|
||||
resp = requests.get(direct, timeout=60)
|
||||
resp.raise_for_status()
|
||||
filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png"
|
||||
path = os.path.join(DOWNLOAD_FOLDER, filename)
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
return path
|
||||
|
||||
|
||||
def main():
|
||||
payload = {
|
||||
"prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
|
||||
"num_inference_steps": 30,
|
||||
"num_images_per_prompt": 1,
|
||||
}
|
||||
|
||||
print("Sending request...")
|
||||
try:
|
||||
r = requests.post(SERVER_URL, json=payload, timeout=480)
|
||||
r.raise_for_status()
|
||||
except Exception as e:
|
||||
print(f"Request failed: {e}")
|
||||
return
|
||||
|
||||
body = r.json().get("response", [])
|
||||
# Normalize to a list
|
||||
urls = body if isinstance(body, list) else [body] if body else []
|
||||
if not urls:
|
||||
print("No URLs found in the response. Check the server output.")
|
||||
return
|
||||
|
||||
print(f"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...")
|
||||
time.sleep(WAIT_BEFORE_DOWNLOAD)
|
||||
|
||||
for u in urls:
|
||||
try:
|
||||
path = save_from_url(u)
|
||||
print(f"Image saved to: {path}")
|
||||
except Exception as e:
|
||||
print(f"Error downloading {u}: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,2 @@
|
||||
from .requestscopedpipeline import RequestScopedPipeline
|
||||
from .utils import Utils
|
||||
@@ -0,0 +1,296 @@
|
||||
import copy
|
||||
import threading
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def safe_tokenize(tokenizer, *args, lock, **kwargs):
|
||||
with lock:
|
||||
return tokenizer(*args, **kwargs)
|
||||
|
||||
|
||||
class RequestScopedPipeline:
|
||||
DEFAULT_MUTABLE_ATTRS = [
|
||||
"_all_hooks",
|
||||
"_offload_device",
|
||||
"_progress_bar_config",
|
||||
"_progress_bar",
|
||||
"_rng_state",
|
||||
"_last_seed",
|
||||
"latents",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline: Any,
|
||||
mutable_attrs: Optional[Iterable[str]] = None,
|
||||
auto_detect_mutables: bool = True,
|
||||
tensor_numel_threshold: int = 1_000_000,
|
||||
tokenizer_lock: Optional[threading.Lock] = None,
|
||||
wrap_scheduler: bool = True,
|
||||
):
|
||||
self._base = pipeline
|
||||
self.unet = getattr(pipeline, "unet", None)
|
||||
self.vae = getattr(pipeline, "vae", None)
|
||||
self.text_encoder = getattr(pipeline, "text_encoder", None)
|
||||
self.components = getattr(pipeline, "components", None)
|
||||
|
||||
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
|
||||
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
|
||||
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
|
||||
|
||||
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
|
||||
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
|
||||
|
||||
self._auto_detect_mutables = bool(auto_detect_mutables)
|
||||
self._tensor_numel_threshold = int(tensor_numel_threshold)
|
||||
|
||||
self._auto_detected_attrs: List[str] = []
|
||||
|
||||
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
|
||||
base_sched = getattr(self._base, "scheduler", None)
|
||||
if base_sched is None:
|
||||
return None
|
||||
|
||||
if not isinstance(base_sched, BaseAsyncScheduler):
|
||||
wrapped_scheduler = BaseAsyncScheduler(base_sched)
|
||||
else:
|
||||
wrapped_scheduler = base_sched
|
||||
|
||||
try:
|
||||
return wrapped_scheduler.clone_for_request(
|
||||
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
|
||||
try:
|
||||
return copy.deepcopy(wrapped_scheduler)
|
||||
except Exception as e:
|
||||
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
|
||||
return wrapped_scheduler
|
||||
|
||||
def _autodetect_mutables(self, max_attrs: int = 40):
|
||||
if not self._auto_detect_mutables:
|
||||
return []
|
||||
|
||||
if self._auto_detected_attrs:
|
||||
return self._auto_detected_attrs
|
||||
|
||||
candidates: List[str] = []
|
||||
seen = set()
|
||||
for name in dir(self._base):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
if name in self._mutable_attrs:
|
||||
continue
|
||||
if name in ("to", "save_pretrained", "from_pretrained"):
|
||||
continue
|
||||
try:
|
||||
val = getattr(self._base, name)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
import types
|
||||
|
||||
# skip callables and modules
|
||||
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
|
||||
continue
|
||||
|
||||
# containers -> candidate
|
||||
if isinstance(val, (dict, list, set, tuple, bytearray)):
|
||||
candidates.append(name)
|
||||
seen.add(name)
|
||||
else:
|
||||
# try Tensor detection
|
||||
try:
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.numel() <= self._tensor_numel_threshold:
|
||||
candidates.append(name)
|
||||
seen.add(name)
|
||||
else:
|
||||
logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if len(candidates) >= max_attrs:
|
||||
break
|
||||
|
||||
self._auto_detected_attrs = candidates
|
||||
logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
|
||||
return self._auto_detected_attrs
|
||||
|
||||
def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
|
||||
try:
|
||||
cls = type(base_obj)
|
||||
descriptor = getattr(cls, attr_name, None)
|
||||
if isinstance(descriptor, property):
|
||||
return descriptor.fset is None
|
||||
if hasattr(descriptor, "__set__") is False and descriptor is not None:
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _clone_mutable_attrs(self, base, local):
|
||||
attrs_to_clone = list(self._mutable_attrs)
|
||||
attrs_to_clone.extend(self._autodetect_mutables())
|
||||
|
||||
EXCLUDE_ATTRS = {
|
||||
"components",
|
||||
}
|
||||
|
||||
for attr in attrs_to_clone:
|
||||
if attr in EXCLUDE_ATTRS:
|
||||
logger.debug(f"Skipping excluded attr '{attr}'")
|
||||
continue
|
||||
if not hasattr(base, attr):
|
||||
continue
|
||||
if self._is_readonly_property(base, attr):
|
||||
logger.debug(f"Skipping read-only property '{attr}'")
|
||||
continue
|
||||
|
||||
try:
|
||||
val = getattr(base, attr)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
setattr(local, attr, dict(val))
|
||||
elif isinstance(val, (list, tuple, set)):
|
||||
setattr(local, attr, list(val))
|
||||
elif isinstance(val, bytearray):
|
||||
setattr(local, attr, bytearray(val))
|
||||
else:
|
||||
# small tensors or atomic values
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.numel() <= self._tensor_numel_threshold:
|
||||
setattr(local, attr, val.clone())
|
||||
else:
|
||||
# don't clone big tensors, keep reference
|
||||
setattr(local, attr, val)
|
||||
else:
|
||||
try:
|
||||
setattr(local, attr, copy.copy(val))
|
||||
except Exception:
|
||||
setattr(local, attr, val)
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
|
||||
continue
|
||||
|
||||
def _is_tokenizer_component(self, component) -> bool:
|
||||
if component is None:
|
||||
return False
|
||||
|
||||
tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
|
||||
has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
|
||||
|
||||
class_name = component.__class__.__name__.lower()
|
||||
has_tokenizer_in_name = "tokenizer" in class_name
|
||||
|
||||
tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
|
||||
has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
|
||||
|
||||
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
|
||||
|
||||
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
|
||||
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
|
||||
|
||||
try:
|
||||
local_pipe = copy.copy(self._base)
|
||||
except Exception as e:
|
||||
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
|
||||
local_pipe = copy.deepcopy(self._base)
|
||||
|
||||
if local_scheduler is not None:
|
||||
try:
|
||||
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
|
||||
local_scheduler.scheduler,
|
||||
num_inference_steps=num_inference_steps,
|
||||
device=device,
|
||||
return_scheduler=True,
|
||||
**{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
|
||||
)
|
||||
|
||||
final_scheduler = BaseAsyncScheduler(configured_scheduler)
|
||||
setattr(local_pipe, "scheduler", final_scheduler)
|
||||
except Exception:
|
||||
logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
|
||||
|
||||
self._clone_mutable_attrs(self._base, local_pipe)
|
||||
|
||||
# 4) wrap tokenizers on the local pipe with the lock wrapper
|
||||
tokenizer_wrappers = {} # name -> original_tokenizer
|
||||
try:
|
||||
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
|
||||
for name in dir(local_pipe):
|
||||
if "tokenizer" in name and not name.startswith("_"):
|
||||
tok = getattr(local_pipe, name, None)
|
||||
if tok is not None and self._is_tokenizer_component(tok):
|
||||
tokenizer_wrappers[name] = tok
|
||||
setattr(
|
||||
local_pipe,
|
||||
name,
|
||||
lambda *args, tok=tok, **kwargs: safe_tokenize(
|
||||
tok, *args, lock=self._tokenizer_lock, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
# b) wrap tokenizers in components dict
|
||||
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
|
||||
for key, val in local_pipe.components.items():
|
||||
if val is None:
|
||||
continue
|
||||
|
||||
if self._is_tokenizer_component(val):
|
||||
tokenizer_wrappers[f"components[{key}]"] = val
|
||||
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
|
||||
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
|
||||
|
||||
result = None
|
||||
cm = getattr(local_pipe, "model_cpu_offload_context", None)
|
||||
try:
|
||||
if callable(cm):
|
||||
try:
|
||||
with cm():
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
except TypeError:
|
||||
# cm might be a context manager instance rather than callable
|
||||
try:
|
||||
with cm:
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
except Exception as e:
|
||||
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
else:
|
||||
# no offload context available — call directly
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
try:
|
||||
for name, tok in tokenizer_wrappers.items():
|
||||
if name.startswith("components["):
|
||||
key = name[len("components[") : -1]
|
||||
local_pipe.components[key] = tok
|
||||
else:
|
||||
setattr(local_pipe, name, tok)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error restoring wrapped tokenizers: {e}")
|
||||
@@ -0,0 +1,141 @@
|
||||
import copy
|
||||
import inspect
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseAsyncScheduler:
|
||||
def __init__(self, scheduler: Any):
|
||||
self.scheduler = scheduler
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if hasattr(self.scheduler, name):
|
||||
return getattr(self.scheduler, name)
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def __setattr__(self, name: str, value):
|
||||
if name == "scheduler":
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
|
||||
setattr(self.scheduler, name, value)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
|
||||
local = copy.deepcopy(self.scheduler)
|
||||
local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
|
||||
cloned = self.__class__(local)
|
||||
return cloned
|
||||
|
||||
def __repr__(self):
|
||||
return f"BaseAsyncScheduler({repr(self.scheduler)})"
|
||||
|
||||
def __str__(self):
|
||||
return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
|
||||
|
||||
|
||||
def async_retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
|
||||
Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Backwards compatible: by default the function behaves exactly as before and returns
|
||||
(timesteps_tensor, num_inference_steps)
|
||||
|
||||
If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
|
||||
scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
|
||||
or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:
|
||||
(timesteps_tensor, num_inference_steps, scheduler_in_use)
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Optional kwargs:
|
||||
return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)
|
||||
where `scheduler_in_use` is a scheduler instance that already has timesteps set.
|
||||
This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.
|
||||
|
||||
Returns:
|
||||
`(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
|
||||
`(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
|
||||
"""
|
||||
# pop our optional control kwarg (keeps compatibility)
|
||||
return_scheduler = bool(kwargs.pop("return_scheduler", False))
|
||||
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
|
||||
# choose scheduler to call set_timesteps on
|
||||
scheduler_in_use = scheduler
|
||||
if return_scheduler:
|
||||
# Do not mutate the provided scheduler: prefer to clone if possible
|
||||
if hasattr(scheduler, "clone_for_request"):
|
||||
try:
|
||||
# clone_for_request may accept num_inference_steps or other kwargs; be permissive
|
||||
scheduler_in_use = scheduler.clone_for_request(
|
||||
num_inference_steps=num_inference_steps or 0, device=device
|
||||
)
|
||||
except Exception:
|
||||
scheduler_in_use = copy.deepcopy(scheduler)
|
||||
else:
|
||||
# fallback deepcopy (scheduler tends to be smallish - acceptable)
|
||||
scheduler_in_use = copy.deepcopy(scheduler)
|
||||
|
||||
# helper to test if set_timesteps supports a particular kwarg
|
||||
def _accepts(param_name: str) -> bool:
|
||||
try:
|
||||
return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())
|
||||
except (ValueError, TypeError):
|
||||
# if signature introspection fails, be permissive and attempt the call later
|
||||
return False
|
||||
|
||||
# now call set_timesteps on the chosen scheduler_in_use (may be original or clone)
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = _accepts("timesteps")
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
num_inference_steps = len(timesteps_out)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = _accepts("sigmas")
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
num_inference_steps = len(timesteps_out)
|
||||
else:
|
||||
# default path
|
||||
scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
|
||||
if return_scheduler:
|
||||
return timesteps_out, num_inference_steps, scheduler_in_use
|
||||
return timesteps_out, num_inference_steps
|
||||
@@ -0,0 +1,48 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Utils:
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8500):
|
||||
self.service_url = f"http://{host}:{port}"
|
||||
self.image_dir = os.path.join(tempfile.gettempdir(), "images")
|
||||
if not os.path.exists(self.image_dir):
|
||||
os.makedirs(self.image_dir)
|
||||
|
||||
self.video_dir = os.path.join(tempfile.gettempdir(), "videos")
|
||||
if not os.path.exists(self.video_dir):
|
||||
os.makedirs(self.video_dir)
|
||||
|
||||
def save_image(self, image):
|
||||
if hasattr(image, "to"):
|
||||
try:
|
||||
image = image.to("cpu")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
from torchvision import transforms
|
||||
|
||||
to_pil = transforms.ToPILImage()
|
||||
image = to_pil(image.squeeze(0).clamp(0, 1))
|
||||
|
||||
filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png"
|
||||
image_path = os.path.join(self.image_dir, filename)
|
||||
logger.info(f"Saving image to {image_path}")
|
||||
|
||||
image.save(image_path, format="PNG", optimize=True)
|
||||
|
||||
del image
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return os.path.join(self.service_url, "images", filename)
|
||||
@@ -9,8 +9,8 @@ This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server
|
||||
Start by navigating to the `examples/server` folder and installing all of the dependencies.
|
||||
|
||||
```py
|
||||
pip install .
|
||||
pip install -f requirements.txt
|
||||
pip install diffusers
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Launch the server with the following command.
|
||||
|
||||
@@ -6,4 +6,5 @@ py-consul
|
||||
prometheus_client >= 0.18.0
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
fastapi
|
||||
uvicorn
|
||||
uvicorn
|
||||
accelerate
|
||||
|
||||
@@ -39,7 +39,7 @@ fsspec==2024.10.0
|
||||
# torch
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
huggingface-hub==0.26.1
|
||||
huggingface-hub==0.35.0
|
||||
# via
|
||||
# tokenizers
|
||||
# transformers
|
||||
|
||||
@@ -102,7 +102,8 @@ _deps = [
|
||||
"filelock",
|
||||
"flax>=0.4.1",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub>=0.34.0",
|
||||
"httpx<1.0.0",
|
||||
"huggingface-hub>=0.34.0,<2.0",
|
||||
"requests-mock==1.10.0",
|
||||
"importlib_metadata",
|
||||
"invisible-watermark>=0.2.0",
|
||||
@@ -259,6 +260,7 @@ extras["dev"] = (
|
||||
install_requires = [
|
||||
deps["importlib_metadata"],
|
||||
deps["filelock"],
|
||||
deps["httpx"],
|
||||
deps["huggingface-hub"],
|
||||
deps["numpy"],
|
||||
deps["regex"],
|
||||
|
||||
@@ -202,6 +202,7 @@ else:
|
||||
"CogView4Transformer2DModel",
|
||||
"ConsisIDTransformer3DModel",
|
||||
"ConsistencyDecoderVAE",
|
||||
"ContextParallelConfig",
|
||||
"ControlNetModel",
|
||||
"ControlNetUnionModel",
|
||||
"ControlNetXSAdapter",
|
||||
@@ -229,6 +230,7 @@ else:
|
||||
"MultiAdapter",
|
||||
"MultiControlNetModel",
|
||||
"OmniGenTransformer2DModel",
|
||||
"ParallelConfig",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"QwenImageControlNetModel",
|
||||
@@ -495,6 +497,7 @@ else:
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
"LucyEditPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
"LuminaPipeline",
|
||||
@@ -514,6 +517,7 @@ else:
|
||||
"QwenImageControlNetPipeline",
|
||||
"QwenImageEditInpaintPipeline",
|
||||
"QwenImageEditPipeline",
|
||||
"QwenImageEditPlusPipeline",
|
||||
"QwenImageImg2ImgPipeline",
|
||||
"QwenImageInpaintPipeline",
|
||||
"QwenImagePipeline",
|
||||
@@ -886,6 +890,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4Transformer2DModel,
|
||||
ConsisIDTransformer3DModel,
|
||||
ConsistencyDecoderVAE,
|
||||
ContextParallelConfig,
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
@@ -913,6 +918,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MultiAdapter,
|
||||
MultiControlNetModel,
|
||||
OmniGenTransformer2DModel,
|
||||
ParallelConfig,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
QwenImageControlNetModel,
|
||||
@@ -1149,6 +1155,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
LucyEditPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
LuminaPipeline,
|
||||
@@ -1168,6 +1175,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
QwenImageEditPipeline,
|
||||
QwenImageEditPlusPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImagePipeline,
|
||||
|
||||
@@ -30,11 +30,11 @@ import numpy as np
|
||||
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
from requests import HTTPError
|
||||
from typing_extensions import Self
|
||||
|
||||
from . import __version__
|
||||
@@ -419,7 +419,7 @@ class ConfigMixin:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
except HfHubHTTPError as err:
|
||||
raise EnvironmentError(
|
||||
"There was a specific connection error when trying to load"
|
||||
f" {pretrained_model_name_or_path}:\n{err}"
|
||||
|
||||
@@ -9,7 +9,8 @@ deps = {
|
||||
"filelock": "filelock",
|
||||
"flax": "flax>=0.4.1",
|
||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.34.0",
|
||||
"httpx": "httpx<1.0.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.34.0,<2.0",
|
||||
"requests-mock": "requests-mock==1.10.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"invisible-watermark": "invisible-watermark>=0.2.0",
|
||||
|
||||
@@ -16,6 +16,7 @@ from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .context_parallel import apply_context_parallel
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
|
||||
@@ -0,0 +1,297 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
|
||||
from ..models._modeling_parallel import (
|
||||
ContextParallelConfig,
|
||||
ContextParallelInput,
|
||||
ContextParallelModelPlan,
|
||||
ContextParallelOutput,
|
||||
)
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
|
||||
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
|
||||
|
||||
|
||||
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
|
||||
@dataclass
|
||||
class ModuleForwardMetadata:
|
||||
cached_parameter_indices: Dict[str, int] = None
|
||||
_cls: Type = None
|
||||
|
||||
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
if identifier in kwargs:
|
||||
return kwargs[identifier], True, None
|
||||
|
||||
if self.cached_parameter_indices is not None:
|
||||
index = self.cached_parameter_indices.get(identifier, None)
|
||||
if index is None:
|
||||
raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
|
||||
return args[index], False, index
|
||||
|
||||
if self._cls is None:
|
||||
raise ValueError("Model class is not set for metadata.")
|
||||
|
||||
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
|
||||
parameters = parameters[1:] # skip `self`
|
||||
self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
|
||||
|
||||
if identifier not in self.cached_parameter_indices:
|
||||
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
|
||||
|
||||
index = self.cached_parameter_indices[identifier]
|
||||
|
||||
if index >= len(args):
|
||||
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
|
||||
|
||||
return args[index], False, index
|
||||
|
||||
|
||||
def apply_context_parallel(
|
||||
module: torch.nn.Module,
|
||||
parallel_config: ContextParallelConfig,
|
||||
plan: Dict[str, ContextParallelModelPlan],
|
||||
) -> None:
|
||||
"""Apply context parallel on a model."""
|
||||
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
|
||||
|
||||
for module_id, cp_model_plan in plan.items():
|
||||
submodule = _get_submodule_by_name(module, module_id)
|
||||
if not isinstance(submodule, list):
|
||||
submodule = [submodule]
|
||||
|
||||
logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
|
||||
|
||||
for m in submodule:
|
||||
if isinstance(cp_model_plan, dict):
|
||||
hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
|
||||
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
|
||||
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
|
||||
if isinstance(cp_model_plan, ContextParallelOutput):
|
||||
cp_model_plan = [cp_model_plan]
|
||||
if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
|
||||
raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
|
||||
hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
|
||||
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
|
||||
else:
|
||||
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(m)
|
||||
registry.register_hook(hook, hook_name)
|
||||
|
||||
|
||||
def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
|
||||
for module_id, cp_model_plan in plan.items():
|
||||
submodule = _get_submodule_by_name(module, module_id)
|
||||
if not isinstance(submodule, list):
|
||||
submodule = [submodule]
|
||||
|
||||
for m in submodule:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(m)
|
||||
if isinstance(cp_model_plan, dict):
|
||||
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
|
||||
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
|
||||
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
|
||||
else:
|
||||
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
|
||||
registry.remove_hook(hook_name)
|
||||
|
||||
|
||||
class ContextParallelSplitHook(ModelHook):
|
||||
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
|
||||
super().__init__()
|
||||
self.metadata = metadata
|
||||
self.parallel_config = parallel_config
|
||||
self.module_forward_metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
cls = unwrap_module(module).__class__
|
||||
self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
|
||||
return module
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
args_list = list(args)
|
||||
|
||||
for name, cpm in self.metadata.items():
|
||||
if isinstance(cpm, ContextParallelInput) and cpm.split_output:
|
||||
continue
|
||||
|
||||
# Maybe the parameter was passed as a keyword argument
|
||||
input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
|
||||
name, args_list, kwargs
|
||||
)
|
||||
|
||||
if input_val is None:
|
||||
continue
|
||||
|
||||
# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
|
||||
# the output instead of input for a particular layer by setting split_output=True
|
||||
if isinstance(input_val, torch.Tensor):
|
||||
input_val = self._prepare_cp_input(input_val, cpm)
|
||||
elif isinstance(input_val, (list, tuple)):
|
||||
if len(input_val) != len(cpm):
|
||||
raise ValueError(
|
||||
f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
|
||||
)
|
||||
sharded_input_val = []
|
||||
for i, x in enumerate(input_val):
|
||||
if torch.is_tensor(x) and not cpm[i].split_output:
|
||||
x = self._prepare_cp_input(x, cpm[i])
|
||||
sharded_input_val.append(x)
|
||||
input_val = sharded_input_val
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(input_val)}")
|
||||
|
||||
if is_kwarg:
|
||||
kwargs[name] = input_val
|
||||
elif index is not None and index < len(args_list):
|
||||
args_list[index] = input_val
|
||||
else:
|
||||
raise ValueError(
|
||||
f"An unexpected error occurred while processing the input '{name}'. Please open an "
|
||||
f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
|
||||
f"example along with the full stack trace."
|
||||
)
|
||||
|
||||
return tuple(args_list), kwargs
|
||||
|
||||
def post_forward(self, module, output):
|
||||
is_tensor = isinstance(output, torch.Tensor)
|
||||
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
|
||||
|
||||
if not is_tensor and not is_tensor_list:
|
||||
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
|
||||
|
||||
output = [output] if is_tensor else list(output)
|
||||
for index, cpm in self.metadata.items():
|
||||
if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
|
||||
continue
|
||||
if index >= len(output):
|
||||
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
|
||||
current_output = output[index]
|
||||
current_output = self._prepare_cp_input(current_output, cpm)
|
||||
output[index] = current_output
|
||||
|
||||
return output[0] if is_tensor else tuple(output)
|
||||
|
||||
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
|
||||
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
|
||||
raise ValueError(
|
||||
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
|
||||
)
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
|
||||
class ContextParallelGatherHook(ModelHook):
|
||||
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
|
||||
super().__init__()
|
||||
self.metadata = metadata
|
||||
self.parallel_config = parallel_config
|
||||
|
||||
def post_forward(self, module, output):
|
||||
is_tensor = isinstance(output, torch.Tensor)
|
||||
|
||||
if is_tensor:
|
||||
output = [output]
|
||||
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
|
||||
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
|
||||
|
||||
output = list(output)
|
||||
|
||||
if len(output) != len(self.metadata):
|
||||
raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
|
||||
|
||||
for i, cpm in enumerate(self.metadata):
|
||||
if cpm is None:
|
||||
continue
|
||||
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
return output[0] if is_tensor else tuple(output)
|
||||
|
||||
|
||||
class AllGatherFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, tensor, dim, group):
|
||||
ctx.dim = dim
|
||||
ctx.group = group
|
||||
ctx.world_size = torch.distributed.get_world_size(group)
|
||||
ctx.rank = torch.distributed.get_rank(group)
|
||||
return funcol.all_gather_tensor(tensor, dim, group=group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
|
||||
return grad_chunks[ctx.rank], None, None
|
||||
|
||||
|
||||
class EquipartitionSharder:
|
||||
@classmethod
|
||||
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
|
||||
# NOTE: the following assertion does not have to be true in general. We simply enforce it for now
|
||||
# because the alternate case has not yet been tested/required for any model.
|
||||
assert tensor.size()[dim] % mesh.size() == 0, (
|
||||
"Tensor size along dimension to be sharded must be divisible by mesh size"
|
||||
)
|
||||
|
||||
# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
|
||||
# return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
|
||||
|
||||
return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
|
||||
|
||||
@classmethod
|
||||
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
|
||||
tensor = tensor.contiguous()
|
||||
tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
|
||||
return tensor
|
||||
|
||||
|
||||
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
if name.count("*") > 1:
|
||||
raise ValueError("Wildcard '*' can only be used once in the name")
|
||||
return _find_submodule_by_name(model, name)
|
||||
|
||||
|
||||
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
if name == "":
|
||||
return model
|
||||
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
|
||||
if first_atom == "*":
|
||||
if not isinstance(model, torch.nn.ModuleList):
|
||||
raise ValueError("Wildcard '*' can only be used with ModuleList")
|
||||
submodules = []
|
||||
for submodule in model:
|
||||
subsubmodules = _find_submodule_by_name(submodule, remaining_name)
|
||||
if not isinstance(subsubmodules, list):
|
||||
subsubmodules = [subsubmodules]
|
||||
submodules.extend(subsubmodules)
|
||||
return submodules
|
||||
else:
|
||||
if hasattr(model, first_atom):
|
||||
submodule = getattr(model, first_atom)
|
||||
return _find_submodule_by_name(submodule, remaining_name)
|
||||
else:
|
||||
raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
|
||||
@@ -1064,6 +1064,41 @@ class LoraBaseMixin:
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
@classmethod
|
||||
def _save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
|
||||
lora_metadata: Dict[str, Optional[dict]],
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
"""
|
||||
Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
|
||||
pipeline types.
|
||||
"""
|
||||
state_dict = {}
|
||||
final_lora_adapter_metadata = {}
|
||||
|
||||
for prefix, layers in lora_layers.items():
|
||||
state_dict.update(cls.pack_weights(layers, prefix))
|
||||
|
||||
for prefix, metadata in lora_metadata.items():
|
||||
if metadata:
|
||||
final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
|
||||
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@@ -558,70 +558,62 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
ait_sd[target_key] = value
|
||||
|
||||
if any("guidance_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_guidance_in_in_layer",
|
||||
"time_text_embed.guidance_embedder.linear_1",
|
||||
)
|
||||
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_guidance_in_out_layer",
|
||||
"time_text_embed.guidance_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("img_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_img_in",
|
||||
"x_embedder",
|
||||
)
|
||||
|
||||
if any("txt_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_txt_in",
|
||||
"context_embedder",
|
||||
)
|
||||
|
||||
if any("time_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_time_in_in_layer",
|
||||
"time_text_embed.timestep_embedder.linear_1",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_time_in_out_layer",
|
||||
"time_text_embed.timestep_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("vector_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_vector_in_in_layer",
|
||||
"time_text_embed.text_embedder.linear_1",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_vector_in_out_layer",
|
||||
"time_text_embed.text_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("final_layer" in k for k in sds_sd):
|
||||
|
||||
+256
-2362
File diff suppressed because it is too large
Load Diff
@@ -25,6 +25,7 @@ from ..utils import (
|
||||
_import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
|
||||
_import_structure["auto_model"] = ["AutoModel"]
|
||||
@@ -119,6 +120,7 @@ if is_flax_available():
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from ._modeling_parallel import ContextParallelConfig, ParallelConfig
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .attention_dispatch import AttentionBackendName, attention_backend
|
||||
from .auto_model import AutoModel
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨
|
||||
# Experimental changes are subject to change and APIs may break without warning.
|
||||
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# TODO(aryan): add support for the following:
|
||||
# - Unified Attention
|
||||
# - More dispatcher attention backends
|
||||
# - CFG/Data Parallel
|
||||
# - Tensor Parallel
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextParallelConfig:
|
||||
"""
|
||||
Configuration for context parallelism.
|
||||
|
||||
Args:
|
||||
ring_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
ulysses_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert output and LSE to float32 for ring attention numerical stability.
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
|
||||
is supported.
|
||||
|
||||
"""
|
||||
|
||||
ring_degree: Optional[int] = None
|
||||
ulysses_degree: Optional[int] = None
|
||||
convert_to_fp32: bool = True
|
||||
# TODO: support alltoall
|
||||
rotate_method: Literal["allgather", "alltoall"] = "allgather"
|
||||
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
_device: torch.device = None
|
||||
_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_ring_local_rank: int = None
|
||||
_ulysses_local_rank: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.ring_degree is None:
|
||||
self.ring_degree = 1
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
|
||||
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._mesh = mesh
|
||||
if self.ring_degree is None:
|
||||
self.ring_degree = 1
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
)
|
||||
if self._flattened_mesh is None:
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
if self._ring_mesh is None:
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
if self._ulysses_mesh is None:
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
if self._ring_local_rank is None:
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
if self._ulysses_local_rank is None:
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelConfig:
|
||||
"""
|
||||
Configuration for applying different parallelisms.
|
||||
|
||||
Args:
|
||||
context_parallel_config (`ContextParallelConfig`, *optional*):
|
||||
Configuration for context parallelism.
|
||||
"""
|
||||
|
||||
context_parallel_config: Optional[ContextParallelConfig] = None
|
||||
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
_device: torch.device = None
|
||||
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
|
||||
def setup(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
*,
|
||||
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._cp_mesh = cp_mesh
|
||||
if self.context_parallel_config is not None:
|
||||
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextParallelInput:
|
||||
"""
|
||||
Configuration for splitting an input tensor across context parallel region.
|
||||
|
||||
Args:
|
||||
split_dim (`int`):
|
||||
The dimension along which to split the tensor.
|
||||
expected_dims (`int`, *optional*):
|
||||
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
|
||||
tensor has the expected number of dimensions before splitting.
|
||||
split_output (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
|
||||
This is useful for layers whose outputs should be split after it does some preprocessing on the inputs (ex:
|
||||
RoPE).
|
||||
"""
|
||||
|
||||
split_dim: int
|
||||
expected_dims: Optional[int] = None
|
||||
split_output: bool = False
|
||||
|
||||
def __repr__(self):
|
||||
return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextParallelOutput:
|
||||
"""
|
||||
Configuration for gathering an output tensor across context parallel region.
|
||||
|
||||
Args:
|
||||
gather_dim (`int`):
|
||||
The dimension along which to gather the tensor.
|
||||
expected_dims (`int`, *optional*):
|
||||
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
|
||||
tensor has the expected number of dimensions before gathering.
|
||||
"""
|
||||
|
||||
gather_dim: int
|
||||
expected_dims: Optional[int] = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"
|
||||
|
||||
|
||||
# A dictionary where keys denote the input to be split across context parallel region, and the
|
||||
# value denotes the sharding configuration.
|
||||
# If the key is a string, it denotes the name of the parameter in the forward function.
|
||||
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
|
||||
# to be split across context parallel region.
|
||||
ContextParallelInputType = Dict[
|
||||
Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
|
||||
]
|
||||
|
||||
# A dictionary where keys denote the output to be gathered across context parallel region, and the
|
||||
# value denotes the gathering configuration.
|
||||
ContextParallelOutputType = Union[
|
||||
ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
|
||||
]
|
||||
|
||||
# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
|
||||
# the module should be split/gathered across context parallel region.
|
||||
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
|
||||
|
||||
|
||||
# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
|
||||
#
|
||||
# Each model should define a _cp_plan attribute that contains information on how to shard/gather
|
||||
# tensors at different stages of the forward:
|
||||
#
|
||||
# ```python
|
||||
# _cp_plan = {
|
||||
# "": {
|
||||
# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
# },
|
||||
# "pos_embed": {
|
||||
# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
# },
|
||||
# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
# }
|
||||
# ```
|
||||
#
|
||||
# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
|
||||
# split/gathered according to this at the respective module level. Here, the following happens:
|
||||
# - "":
|
||||
# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
|
||||
# the actual forward logic of the QwenImageTransformer2DModel is run, we will splitthe inputs)
|
||||
# - "pos_embed":
|
||||
# we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (imag & text freqs),
|
||||
# we can individually specify how they should be split
|
||||
# - "proj_out":
|
||||
# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
|
||||
# layer forward has run).
|
||||
#
|
||||
# ContextParallelInput:
|
||||
# specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
|
||||
#
|
||||
# ContextParallelOutput:
|
||||
# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
|
||||
@@ -241,7 +241,7 @@ class AttentionModuleMixin:
|
||||
op_fw, op_bw = attention_op
|
||||
dtype, *_ = op_fw.SUPPORTED_DTYPES
|
||||
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
|
||||
_ = xops.memory_efficient_attention(q, q, q)
|
||||
_ = xops.ops.memory_efficient_attention(q, q, q)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -674,7 +674,7 @@ class JointTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import logging
|
||||
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -114,6 +115,8 @@ class AutoModel(ConfigMixin):
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
trust_remote_cocde (`bool`, *optional*, defaults to `False`):
|
||||
Whether to trust remote code
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -140,22 +143,22 @@ class AutoModel(ConfigMixin):
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
```
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
|
||||
load_config_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"token": token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
}
|
||||
hub_kwargs_names = [
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
|
||||
|
||||
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
|
||||
|
||||
library = None
|
||||
orig_class_name = None
|
||||
@@ -189,15 +192,35 @@ class AutoModel(ConfigMixin):
|
||||
else:
|
||||
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
|
||||
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||
trust_remote_code = resolve_trust_remote_code(trust_remote_code, pretrained_model_or_path, has_remote_code)
|
||||
if not has_remote_code and trust_remote_code:
|
||||
raise ValueError(
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
|
||||
model_cls, _ = get_class_obj_and_candidates(
|
||||
library_name=library,
|
||||
class_name=orig_class_name,
|
||||
importable_classes=ALL_IMPORTABLE_CLASSES,
|
||||
pipelines=None,
|
||||
is_pipeline_module=False,
|
||||
)
|
||||
if has_remote_code and trust_remote_code:
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
module_file = module_file + ".py"
|
||||
model_cls = get_class_from_dynamic_module(
|
||||
pretrained_model_or_path,
|
||||
subfolder=subfolder,
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
model_cls, _ = get_class_obj_and_candidates(
|
||||
library_name=library,
|
||||
class_name=orig_class_name,
|
||||
importable_classes=ALL_IMPORTABLE_CLASSES,
|
||||
pipelines=None,
|
||||
is_pipeline_module=False,
|
||||
)
|
||||
|
||||
if model_cls is None:
|
||||
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
|
||||
|
||||
@@ -617,7 +617,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.size(0) > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
@@ -1052,7 +1052,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
is_residual=is_residual,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
||||
self.spatial_compression_ratio = scale_factor_spatial
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
@@ -1145,12 +1145,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def _encode(self, x: torch.Tensor):
|
||||
_, _, num_frame, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
self.clear_cache()
|
||||
if self.config.patch_size is not None:
|
||||
x = patchify(x, patch_size=self.config.patch_size)
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
iter_ = 1 + (num_frame - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
|
||||
@@ -26,11 +26,11 @@ from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from huggingface_hub import create_repo, hf_hub_download
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
from requests import HTTPError
|
||||
|
||||
from .. import __version__, is_torch_available
|
||||
from ..utils import (
|
||||
@@ -385,7 +385,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
except HfHubHTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
||||
f"{err}"
|
||||
|
||||
@@ -65,6 +65,7 @@ from ..utils.hub_utils import (
|
||||
populate_model_card,
|
||||
)
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
|
||||
from .model_loading_utils import (
|
||||
_caching_allocator_warmup,
|
||||
_determine_device_map,
|
||||
@@ -248,6 +249,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
_skip_layerwise_casting_patterns = None
|
||||
_supports_group_offloading = True
|
||||
_repeated_blocks = []
|
||||
_parallel_config = None
|
||||
_cp_plan = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -620,8 +623,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
def reset_attention_backend(self) -> None:
|
||||
"""
|
||||
Resets the attention backend for the model. Following calls to `forward` will use the environment default or
|
||||
the torch native scaled dot product attention.
|
||||
Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
|
||||
set, or the torch native scaled dot product attention.
|
||||
"""
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
@@ -960,6 +963,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
|
||||
|
||||
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
|
||||
if is_parallel_loading_enabled and not low_cpu_mem_usage:
|
||||
@@ -1340,6 +1344,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
if parallel_config is not None:
|
||||
model.enable_parallelism(config=parallel_config)
|
||||
|
||||
if output_loading_info:
|
||||
return model, loading_info
|
||||
|
||||
@@ -1478,6 +1485,73 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
|
||||
)
|
||||
|
||||
def enable_parallelism(
|
||||
self,
|
||||
*,
|
||||
config: Union[ParallelConfig, ContextParallelConfig],
|
||||
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
|
||||
):
|
||||
from ..hooks.context_parallel import apply_context_parallel
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
|
||||
logger.warning(
|
||||
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
|
||||
)
|
||||
|
||||
if isinstance(config, ContextParallelConfig):
|
||||
config = ParallelConfig(context_parallel_config=config)
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
device_type = torch._C._get_accelerator().type
|
||||
device_module = torch.get_device_module(device_type)
|
||||
device = torch.device(device_type, rank % device_module.device_count())
|
||||
|
||||
cp_mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
|
||||
raise ValueError(
|
||||
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
|
||||
mesh_dim_names=("ring", "ulysses"),
|
||||
)
|
||||
|
||||
config.setup(rank, world_size, device, cp_mesh=cp_mesh)
|
||||
|
||||
if cp_plan is None and self._cp_plan is None:
|
||||
raise ValueError(
|
||||
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
|
||||
)
|
||||
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
apply_context_parallel(self, config.context_parallel_config, cp_plan)
|
||||
|
||||
self._parallel_config = config
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
processor = module.processor
|
||||
if processor is None or not hasattr(processor, "_parallel_config"):
|
||||
continue
|
||||
processor._parallel_config = config
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -92,7 +92,7 @@ class AuraFlowPatchEmbed(nn.Module):
|
||||
|
||||
return selected_indices
|
||||
|
||||
def forward(self, latent):
|
||||
def forward(self, latent) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = latent.size()
|
||||
latent = latent.view(
|
||||
batch_size,
|
||||
@@ -173,7 +173,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
@@ -242,7 +242,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = hidden_states
|
||||
residual_context = encoder_hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
@@ -472,7 +472,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
timestep: torch.LongTensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -122,7 +122,7 @@ class CogVideoXBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
@@ -441,7 +441,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -315,7 +315,7 @@ class ConsisIDBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
@@ -691,7 +691,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
id_cond: Optional[torch.Tensor] = None,
|
||||
id_vit_hidden: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -124,7 +124,7 @@ class LuminaNextDiTBlock(nn.Module):
|
||||
encoder_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a forward pass through the LuminaNextDiTBlock.
|
||||
|
||||
@@ -297,7 +297,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
|
||||
image_rotary_emb: torch.Tensor,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
return_dict=True,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
"""
|
||||
Forward pass of LuminaNextDiT.
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ def get_1d_rotary_pos_embed(
|
||||
|
||||
class BriaAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -161,7 +162,12 @@ class BriaAttnProcessor:
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -472,7 +478,7 @@ class BriaSingleTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
@@ -588,7 +594,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
return_dict: bool = True,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`BriaTransformer2DModel`] forward method.
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -79,7 +79,7 @@ class CogView3PlusTransformerBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
@@ -293,7 +293,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`CogView3PlusTransformer2DModel`] forward method.
|
||||
|
||||
|
||||
@@ -494,7 +494,7 @@ class CogView4TransformerBlock(nn.Module):
|
||||
] = None,
|
||||
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Timestep conditioning
|
||||
(
|
||||
norm_hidden_states,
|
||||
@@ -717,7 +717,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -73,6 +74,7 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st
|
||||
|
||||
class FluxAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -114,7 +116,12 @@ class FluxAttnProcessor:
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -136,6 +143,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
"""Flux Attention processor for IP-Adapter."""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
|
||||
@@ -220,6 +228,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -252,6 +261,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
|
||||
@@ -556,6 +566,15 @@ class FluxTransformer2DModel(
|
||||
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -55,7 +55,7 @@ class HiDreamImageTimestepEmbed(nn.Module):
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
|
||||
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
||||
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
|
||||
t_emb = self.timestep_embedder(t_emb)
|
||||
return t_emb
|
||||
@@ -87,7 +87,7 @@ class HiDreamImagePatchEmbed(nn.Module):
|
||||
self.out_channels = out_channels
|
||||
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
|
||||
|
||||
def forward(self, latent):
|
||||
def forward(self, latent) -> torch.Tensor:
|
||||
latent = self.proj(latent)
|
||||
return latent
|
||||
|
||||
@@ -534,7 +534,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
wtype = hidden_states.dtype
|
||||
(
|
||||
shift_msa_i,
|
||||
@@ -592,7 +592,7 @@ class HiDreamBlock(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
return self.block(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_masks=hidden_states_masks,
|
||||
@@ -786,7 +786,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
|
||||
@@ -529,7 +529,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
@@ -684,7 +684,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
token_replace_emb: torch.Tensor = None,
|
||||
num_tokens: int = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
@@ -1038,7 +1038,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
guidance: torch.Tensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -216,7 +216,7 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
indices_latents_history_4x: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -51,6 +52,7 @@ class LTXVideoAttnProcessor:
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if is_torch_version("<", "2.0"):
|
||||
@@ -100,6 +102,7 @@ class LTXVideoAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -409,6 +412,18 @@ class LTXVideoTransformer3DModel(
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
_repeated_blocks = ["LTXVideoTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
"rope": {
|
||||
0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
|
||||
1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -25,6 +25,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
@@ -261,6 +262,7 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -334,6 +336,7 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
@@ -502,6 +505,18 @@ class QwenImageTransformer2DModel(
|
||||
_no_split_modules = ["QwenImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
"pos_embed": {
|
||||
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -73,6 +73,7 @@ def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states
|
||||
|
||||
class SkyReelsV2AttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -139,6 +140,7 @@ class SkyReelsV2AttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
@@ -151,6 +153,7 @@ class SkyReelsV2AttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
@@ -23,6 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -66,6 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
|
||||
|
||||
class WanAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -132,6 +134,7 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
@@ -144,6 +147,7 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
@@ -539,6 +543,19 @@ class WanTransformer3DModel(
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
_repeated_blocks = ["WanTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"rope": {
|
||||
0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
|
||||
1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
|
||||
},
|
||||
"blocks.0": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"blocks.*": {
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -25,6 +25,7 @@ from ..utils import (
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.torch_utils import get_device
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -161,7 +162,9 @@ class AutoOffloadStrategy:
|
||||
|
||||
current_module_size = model.get_memory_footprint()
|
||||
|
||||
mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
|
||||
device_type = execution_device.type
|
||||
device_module = getattr(torch, device_type, torch.cuda)
|
||||
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
|
||||
mem_on_device = mem_on_device - self.memory_reserve_margin
|
||||
if current_module_size < mem_on_device:
|
||||
return []
|
||||
@@ -301,7 +304,7 @@ class ComponentsManager:
|
||||
cm.add("vae", vae_model, collection="sdxl")
|
||||
|
||||
# Enable auto offloading
|
||||
cm.enable_auto_cpu_offload(device="cuda")
|
||||
cm.enable_auto_cpu_offload()
|
||||
|
||||
# Retrieve components
|
||||
unet = cm.get_one(name="unet", collection="sdxl")
|
||||
@@ -490,6 +493,8 @@ class ComponentsManager:
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
if torch.xpu.is_available():
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
# YiYi TODO: rename to search_components for now, may remove this method
|
||||
def search_components(
|
||||
@@ -678,7 +683,7 @@ class ComponentsManager:
|
||||
|
||||
return get_return_dict(matches, return_dict_with_names)
|
||||
|
||||
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
|
||||
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, memory_reserve_margin="3GB"):
|
||||
"""
|
||||
Enable automatic CPU offloading for all components.
|
||||
|
||||
@@ -704,6 +709,8 @@ class ComponentsManager:
|
||||
|
||||
self.disable_auto_cpu_offload()
|
||||
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
|
||||
if device is None:
|
||||
device = get_device()
|
||||
device = torch.device(device)
|
||||
if device.index is None:
|
||||
device = torch.device(f"{device.type}:{0}")
|
||||
|
||||
@@ -32,6 +32,8 @@ class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversion
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
default_blocks_name = "FluxAutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@@ -0,0 +1,763 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Simple typed wrapper for parameter overrides
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from huggingface_hub import create_repo, hf_hub_download
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
|
||||
from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, PushToHubMixin, extract_commit_hash
|
||||
from .modular_pipeline import ModularPipelineBlocks
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SUPPORTED_NODE_TYPES = {"controlnet", "vae_encoder", "denoise", "text_encoder", "decoder"}
|
||||
|
||||
|
||||
# Mellon Input Parameters (runtime parameters, not models)
|
||||
MELLON_INPUT_PARAMS = {
|
||||
# controlnet
|
||||
"control_image": {
|
||||
"label": "Control Image",
|
||||
"type": "image",
|
||||
"display": "input",
|
||||
},
|
||||
"controlnet_conditioning_scale": {
|
||||
"label": "Scale",
|
||||
"type": "float",
|
||||
"default": 0.5,
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
},
|
||||
"control_guidance_end": {
|
||||
"label": "End",
|
||||
"type": "float",
|
||||
"default": 1.0,
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
},
|
||||
"control_guidance_start": {
|
||||
"label": "Start",
|
||||
"type": "float",
|
||||
"default": 0.0,
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
},
|
||||
"controlnet": {
|
||||
"label": "Controlnet",
|
||||
"type": "custom_controlnet",
|
||||
"display": "input",
|
||||
},
|
||||
"embeddings": {
|
||||
"label": "Text Embeddings",
|
||||
"display": "input",
|
||||
"type": "embeddings",
|
||||
},
|
||||
"image": {
|
||||
"label": "Image",
|
||||
"type": "image",
|
||||
"display": "input",
|
||||
},
|
||||
"negative_prompt": {
|
||||
"label": "Negative Prompt",
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"display": "textarea",
|
||||
},
|
||||
"prompt": {
|
||||
"label": "Prompt",
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"display": "textarea",
|
||||
},
|
||||
"guidance_scale": {
|
||||
"label": "Guidance Scale",
|
||||
"type": "float",
|
||||
"display": "slider",
|
||||
"default": 5,
|
||||
"min": 1.0,
|
||||
"max": 30.0,
|
||||
"step": 0.1,
|
||||
},
|
||||
"height": {
|
||||
"label": "Height",
|
||||
"type": "int",
|
||||
"default": 1024,
|
||||
"min": 64,
|
||||
"step": 8,
|
||||
},
|
||||
"image_latents": {
|
||||
"label": "Image Latents",
|
||||
"type": "latents",
|
||||
"display": "input",
|
||||
"onChange": {False: ["height", "width"], True: ["strength"]},
|
||||
},
|
||||
"latents": {
|
||||
"label": "Latents",
|
||||
"type": "latents",
|
||||
"display": "input",
|
||||
},
|
||||
"num_inference_steps": {
|
||||
"label": "Steps",
|
||||
"type": "int",
|
||||
"display": "slider",
|
||||
"default": 25,
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
},
|
||||
"seed": {
|
||||
"label": "Seed",
|
||||
"type": "int",
|
||||
"display": "random",
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967295,
|
||||
},
|
||||
"strength": {
|
||||
"label": "Strength",
|
||||
"type": "float",
|
||||
"default": 0.5,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
},
|
||||
"width": {
|
||||
"label": "Width",
|
||||
"type": "int",
|
||||
"default": 1024,
|
||||
"min": 64,
|
||||
"step": 8,
|
||||
},
|
||||
"ip_adapter": {
|
||||
"label": "IP Adapter",
|
||||
"type": "custom_ip_adapter",
|
||||
"display": "input",
|
||||
},
|
||||
}
|
||||
|
||||
# Mellon Model Parameters (diffusers_auto_model types)
|
||||
MELLON_MODEL_PARAMS = {
|
||||
"scheduler": {
|
||||
"label": "Scheduler",
|
||||
"display": "input",
|
||||
"type": "diffusers_auto_model",
|
||||
},
|
||||
"text_encoders": {
|
||||
"label": "Text Encoders",
|
||||
"type": "diffusers_auto_models",
|
||||
"display": "input",
|
||||
},
|
||||
"unet": {
|
||||
"label": "Unet",
|
||||
"display": "input",
|
||||
"type": "diffusers_auto_model",
|
||||
"onSignal": {
|
||||
"action": "signal",
|
||||
"target": "guider",
|
||||
},
|
||||
},
|
||||
"guider": {
|
||||
"label": "Guider",
|
||||
"display": "input",
|
||||
"type": "custom_guider",
|
||||
"onChange": {False: ["guidance_scale"], True: []},
|
||||
},
|
||||
"vae": {
|
||||
"label": "VAE",
|
||||
"display": "input",
|
||||
"type": "diffusers_auto_model",
|
||||
},
|
||||
"controlnet": {
|
||||
"label": "Controlnet Model",
|
||||
"type": "diffusers_auto_model",
|
||||
"display": "input",
|
||||
},
|
||||
}
|
||||
|
||||
# Mellon Output Parameters (display = "output")
|
||||
MELLON_OUTPUT_PARAMS = {
|
||||
"embeddings": {
|
||||
"label": "Text Embeddings",
|
||||
"display": "output",
|
||||
"type": "embeddings",
|
||||
},
|
||||
"images": {
|
||||
"label": "Images",
|
||||
"type": "image",
|
||||
"display": "output",
|
||||
},
|
||||
"image_latents": {
|
||||
"label": "Image Latents",
|
||||
"type": "latents",
|
||||
"display": "output",
|
||||
},
|
||||
"latents": {
|
||||
"label": "Latents",
|
||||
"type": "latents",
|
||||
"display": "output",
|
||||
},
|
||||
"latents_preview": {
|
||||
"label": "Latents Preview",
|
||||
"display": "output",
|
||||
"type": "latent",
|
||||
},
|
||||
"controlnet_out": {
|
||||
"label": "Controlnet",
|
||||
"display": "output",
|
||||
"type": "controlnet",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Default param selections per supported node_type
|
||||
# from MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS.
|
||||
NODE_TYPE_PARAMS_MAP = {
|
||||
"controlnet": {
|
||||
"inputs": [
|
||||
"control_image",
|
||||
"controlnet_conditioning_scale",
|
||||
"control_guidance_start",
|
||||
"control_guidance_end",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
"model_inputs": [
|
||||
"controlnet",
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"controlnet",
|
||||
],
|
||||
"block_names": ["controlnet_vae_encoder"],
|
||||
},
|
||||
"denoise": {
|
||||
"inputs": [
|
||||
"embeddings",
|
||||
"width",
|
||||
"height",
|
||||
"seed",
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"image_latents",
|
||||
"strength",
|
||||
# custom adapters coming in as inputs
|
||||
"controlnet",
|
||||
# ip_adapter is optional and custom; include if available
|
||||
"ip_adapter",
|
||||
],
|
||||
"model_inputs": [
|
||||
"unet",
|
||||
"guider",
|
||||
"scheduler",
|
||||
],
|
||||
"outputs": [
|
||||
"latents",
|
||||
"latents_preview",
|
||||
],
|
||||
"block_names": ["denoise"],
|
||||
},
|
||||
"vae_encoder": {
|
||||
"inputs": [
|
||||
"image",
|
||||
"width",
|
||||
"height",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"image_latents",
|
||||
],
|
||||
"block_names": ["vae_encoder"],
|
||||
},
|
||||
"text_encoder": {
|
||||
"inputs": [
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
# optional image prompt input supported in embeddings node
|
||||
"image",
|
||||
],
|
||||
"model_inputs": [
|
||||
"text_encoders",
|
||||
],
|
||||
"outputs": [
|
||||
"embeddings",
|
||||
],
|
||||
"block_names": ["text_encoder"],
|
||||
},
|
||||
"decoder": {
|
||||
"inputs": [
|
||||
"latents",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"images",
|
||||
],
|
||||
"block_names": ["decode"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MellonParam:
|
||||
name: str
|
||||
label: str
|
||||
type: str
|
||||
display: Optional[str] = None
|
||||
default: Any = None
|
||||
min: Optional[float] = None
|
||||
max: Optional[float] = None
|
||||
step: Optional[float] = None
|
||||
options: Any = None
|
||||
value: Any = None
|
||||
fieldOptions: Optional[Dict[str, Any]] = None
|
||||
onChange: Any = None
|
||||
onSignal: Any = None
|
||||
_map_to_input: Any = None # the block input name this parameter maps to
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
data = asdict(self)
|
||||
return {k: v for k, v in data.items() if not k.startswith("_") and v is not None}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MellonNodeConfig(PushToHubMixin):
|
||||
"""
|
||||
A MellonNodeConfig is a base class to build Mellon nodes UI with modular diffusers.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental feature and is likely to change in the future.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
inputs: List[Union[str, MellonParam]]
|
||||
model_inputs: List[Union[str, MellonParam]]
|
||||
outputs: List[Union[str, MellonParam]]
|
||||
blocks_names: list[str]
|
||||
node_type: str
|
||||
config_name = "mellon_config.json"
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.inputs, list):
|
||||
self.inputs = self._resolve_params_list(self.inputs, MELLON_INPUT_PARAMS)
|
||||
if isinstance(self.model_inputs, list):
|
||||
self.model_inputs = self._resolve_params_list(self.model_inputs, MELLON_MODEL_PARAMS)
|
||||
if isinstance(self.outputs, list):
|
||||
self.outputs = self._resolve_params_list(self.outputs, MELLON_OUTPUT_PARAMS)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_params_list(
|
||||
params: List[Union[str, MellonParam]], default_map: Dict[str, Dict[str, Any]]
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
def _resolve_param(
|
||||
param: Union[str, MellonParam], default_params_map: Dict[str, Dict[str, Any]]
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
if isinstance(param, str):
|
||||
if param not in default_params_map:
|
||||
raise ValueError(f"Unknown param '{param}', please define a `MellonParam` object instead")
|
||||
return param, default_params_map[param].copy()
|
||||
elif isinstance(param, MellonParam):
|
||||
param_dict = param.to_dict()
|
||||
param_name = param_dict.pop("name")
|
||||
return param_name, param_dict
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown param type '{type(param)}', please use a string or a `MellonParam` object instead"
|
||||
)
|
||||
|
||||
resolved = {}
|
||||
for p in params:
|
||||
logger.info(f" Resolving param: {p}")
|
||||
name, cfg = _resolve_param(p, default_map)
|
||||
if name in resolved:
|
||||
raise ValueError(f"Duplicate param '{name}'")
|
||||
resolved[name] = cfg
|
||||
return resolved
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def load_mellon_config(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
return_unused_kwargs=False,
|
||||
return_commit_hash=False,
|
||||
**kwargs,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
r"""
|
||||
Load a model or scheduler configuration.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
|
||||
[`~ConfigMixin.save_config`].
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False):
|
||||
Whether unused keyword arguments of the config are returned.
|
||||
return_commit_hash (`bool`, *optional*, defaults to `False):
|
||||
Whether the `commit_hash` of the loaded configuration are returned.
|
||||
|
||||
Returns:
|
||||
`dict`:
|
||||
A dictionary of all the parameters stored in a JSON configuration file.
|
||||
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_dir = kwargs.pop("local_dir", None)
|
||||
local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
|
||||
if cls.config_name is None:
|
||||
raise ValueError(
|
||||
"`self.config_name` is not defined. Note that one should not load a config from "
|
||||
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
||||
)
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
||||
# Load from a PyTorch checkpoint
|
||||
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
filename=cls.config_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_dir=local_dir,
|
||||
local_dir_use_symlinks=local_dir_use_symlinks,
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
||||
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
|
||||
" token having permission to this repo with `token` or log in with `hf auth login`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
|
||||
" this model name. Check the model page at"
|
||||
f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
||||
)
|
||||
except HfHubHTTPError as err:
|
||||
raise EnvironmentError(
|
||||
"There was a specific connection error when trying to load"
|
||||
f" {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
|
||||
" run the library in offline mode at"
|
||||
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {cls.config_name} file"
|
||||
)
|
||||
try:
|
||||
with open(config_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
config_dict = json.loads(text)
|
||||
|
||||
commit_hash = extract_commit_hash(config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
||||
|
||||
if not (return_unused_kwargs or return_commit_hash):
|
||||
return config_dict
|
||||
|
||||
outputs = (config_dict,)
|
||||
|
||||
if return_unused_kwargs:
|
||||
outputs += (kwargs,)
|
||||
|
||||
if return_commit_hash:
|
||||
outputs += (commit_hash,)
|
||||
|
||||
return outputs
|
||||
|
||||
def save_mellon_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save the Mellon node definition to a JSON file.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the configuration JSON file is saved (will be created if it does not exist).
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
# If we save using the predefined names, we can load using `from_config`
|
||||
output_config_file = os.path.join(save_directory, self.config_name)
|
||||
|
||||
self.to_json_file(output_config_file)
|
||||
logger.info(f"Mellon node definition saved in {output_config_file}")
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
private = kwargs.pop("private", None)
|
||||
create_pr = kwargs.pop("create_pr", False)
|
||||
token = kwargs.pop("token", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
|
||||
self._upload_folder(
|
||||
save_directory,
|
||||
repo_id,
|
||||
token=token,
|
||||
commit_message=commit_message,
|
||||
create_pr=create_pr,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save the Mellon schema dictionary to a JSON file.
|
||||
|
||||
Args:
|
||||
json_file_path (`str` or `os.PathLike`):
|
||||
Path to the JSON file to save a configuration instance's parameters.
|
||||
"""
|
||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
def to_json_string(self) -> str:
|
||||
"""
|
||||
Serializes this instance to a JSON string of the Mellon schema dict.
|
||||
|
||||
Args:
|
||||
Returns:
|
||||
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
||||
"""
|
||||
|
||||
mellon_dict = self.to_mellon_dict()
|
||||
return json.dumps(mellon_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_mellon_dict(self) -> Dict[str, Any]:
|
||||
"""Return a JSON-serializable dict focusing on the Mellon schema fields only.
|
||||
|
||||
params is a single flat dict composed as: {**inputs, **model_inputs, **outputs}.
|
||||
"""
|
||||
# inputs/model_inputs/outputs are already normalized dicts
|
||||
merged_params = {}
|
||||
merged_params.update(self.inputs or {})
|
||||
merged_params.update(self.model_inputs or {})
|
||||
merged_params.update(self.outputs or {})
|
||||
|
||||
return {
|
||||
"node_type": self.node_type,
|
||||
"blocks_names": self.blocks_names,
|
||||
"params": merged_params,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_mellon_dict(cls, mellon_dict: Dict[str, Any]) -> "MellonNodeConfig":
|
||||
"""Create a config from a Mellon schema dict produced by to_mellon_dict().
|
||||
|
||||
Splits the flat params dict back into inputs/model_inputs/outputs using the known key spaces from
|
||||
MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS. Unknown keys are treated as inputs by
|
||||
default.
|
||||
"""
|
||||
flat_params = mellon_dict.get("params", {})
|
||||
|
||||
inputs: Dict[str, Any] = {}
|
||||
model_inputs: Dict[str, Any] = {}
|
||||
outputs: Dict[str, Any] = {}
|
||||
|
||||
for param_name, param_dict in flat_params.items():
|
||||
if param_dict.get("display", "") == "output":
|
||||
outputs[param_name] = param_dict
|
||||
elif param_dict.get("type", "") in ("diffusers_auto_model", "diffusers_auto_models"):
|
||||
model_inputs[param_name] = param_dict
|
||||
else:
|
||||
inputs[param_name] = param_dict
|
||||
|
||||
return cls(
|
||||
inputs=inputs,
|
||||
model_inputs=model_inputs,
|
||||
outputs=outputs,
|
||||
blocks_names=mellon_dict.get("blocks_names", []),
|
||||
node_type=mellon_dict.get("node_type"),
|
||||
)
|
||||
|
||||
# YiYi Notes: not used yet
|
||||
@classmethod
|
||||
def from_blocks(cls, blocks: ModularPipelineBlocks, node_type: str) -> "MellonNodeConfig":
|
||||
"""
|
||||
Create an instance from a ModularPipeline object. If a preset exists in NODE_TYPE_PARAMS_MAP for the node_type,
|
||||
use it; otherwise fall back to deriving lists from the pipeline's expected inputs/components/outputs.
|
||||
"""
|
||||
if node_type not in NODE_TYPE_PARAMS_MAP:
|
||||
raise ValueError(f"Node type {node_type} not supported")
|
||||
|
||||
blocks_names = list(blocks.sub_blocks.keys())
|
||||
|
||||
default_node_config = NODE_TYPE_PARAMS_MAP[node_type]
|
||||
inputs_list: List[Union[str, MellonParam]] = default_node_config.get("inputs", [])
|
||||
model_inputs_list: List[Union[str, MellonParam]] = default_node_config.get("model_inputs", [])
|
||||
outputs_list: List[Union[str, MellonParam]] = default_node_config.get("outputs", [])
|
||||
|
||||
for required_input_name in blocks.required_inputs:
|
||||
if required_input_name not in inputs_list:
|
||||
inputs_list.append(
|
||||
MellonParam(
|
||||
name=required_input_name, label=required_input_name, type=required_input_name, display="input"
|
||||
)
|
||||
)
|
||||
|
||||
for component_spec in blocks.expected_components:
|
||||
if component_spec.name not in model_inputs_list:
|
||||
model_inputs_list.append(
|
||||
MellonParam(
|
||||
name=component_spec.name,
|
||||
label=component_spec.name,
|
||||
type="diffusers_auto_model",
|
||||
display="input",
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
inputs=inputs_list,
|
||||
model_inputs=model_inputs_list,
|
||||
outputs=outputs_list,
|
||||
blocks_names=blocks_names,
|
||||
node_type=node_type,
|
||||
)
|
||||
|
||||
|
||||
# Minimal modular registry for Mellon node configs
|
||||
class ModularMellonNodeRegistry:
|
||||
"""Registry mapping (pipeline class, blocks_name) -> list of MellonNodeConfig."""
|
||||
|
||||
def __init__(self):
|
||||
self._registry = {}
|
||||
self._initialized = False
|
||||
|
||||
def register(self, pipeline_cls: type, node_params: Dict[str, MellonNodeConfig]):
|
||||
if not self._initialized:
|
||||
_initialize_registry(self)
|
||||
self._registry[pipeline_cls] = node_params
|
||||
|
||||
def get(self, pipeline_cls: type) -> MellonNodeConfig:
|
||||
if not self._initialized:
|
||||
_initialize_registry(self)
|
||||
return self._registry.get(pipeline_cls, None)
|
||||
|
||||
def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]:
|
||||
if not self._initialized:
|
||||
_initialize_registry(self)
|
||||
return self._registry
|
||||
|
||||
|
||||
def _register_preset_node_types(
|
||||
pipeline_cls, params_map: Dict[str, Dict[str, Any]], registry: ModularMellonNodeRegistry
|
||||
):
|
||||
"""Register all node-type presets for a given pipeline class from a params map."""
|
||||
node_configs = {}
|
||||
for node_type, spec in params_map.items():
|
||||
node_config = MellonNodeConfig(
|
||||
inputs=spec.get("inputs", []),
|
||||
model_inputs=spec.get("model_inputs", []),
|
||||
outputs=spec.get("outputs", []),
|
||||
blocks_names=spec.get("block_names", []),
|
||||
node_type=node_type,
|
||||
)
|
||||
node_configs[node_type] = node_config
|
||||
registry.register(pipeline_cls, node_configs)
|
||||
|
||||
|
||||
def _initialize_registry(registry: ModularMellonNodeRegistry):
|
||||
"""Initialize the registry and register all available pipeline configs."""
|
||||
print("Initializing registry")
|
||||
|
||||
registry._initialized = True
|
||||
|
||||
try:
|
||||
from .qwenimage.modular_pipeline import QwenImageModularPipeline
|
||||
from .qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP
|
||||
|
||||
_register_preset_node_types(QwenImageModularPipeline, QwenImage_NODE_TYPES_PARAMS_MAP, registry)
|
||||
except Exception:
|
||||
raise Exception("Failed to register QwenImageModularPipeline")
|
||||
|
||||
try:
|
||||
from .stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline
|
||||
from .stable_diffusion_xl.node_utils import SDXL_NODE_TYPES_PARAMS_MAP
|
||||
|
||||
_register_preset_node_types(StableDiffusionXLModularPipeline, SDXL_NODE_TYPES_PARAMS_MAP, registry)
|
||||
except Exception:
|
||||
raise Exception("Failed to register StableDiffusionXLModularPipeline")
|
||||
@@ -51,6 +51,7 @@ if is_accelerate_available():
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# map regular pipeline to modular pipeline class name
|
||||
MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
|
||||
@@ -61,16 +62,6 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
|
||||
[
|
||||
("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
|
||||
("WanModularPipeline", "WanAutoBlocks"),
|
||||
("FluxModularPipeline", "FluxAutoBlocks"),
|
||||
("QwenImageModularPipeline", "QwenImageAutoBlocks"),
|
||||
("QwenImageEditModularPipeline", "QwenImageEditAutoBlocks"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineState:
|
||||
@@ -323,7 +314,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
)
|
||||
if not (has_remote_code and trust_remote_code):
|
||||
if not has_remote_code and trust_remote_code:
|
||||
raise ValueError(
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
@@ -423,7 +414,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
state.set(input_param.name, param, input_param.kwargs_type)
|
||||
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# if it is a kwargs type, e.g. "denoiser_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediate_kwargs.items():
|
||||
@@ -1454,6 +1445,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
config_name = "modular_model_index.json"
|
||||
hf_device_map = None
|
||||
default_blocks_name = None
|
||||
|
||||
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
|
||||
def __init__(
|
||||
@@ -1514,7 +1506,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
`_blocks_class_name` in the config dict
|
||||
"""
|
||||
if blocks is None:
|
||||
blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__)
|
||||
blocks_class_name = self.default_blocks_name
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
|
||||
@@ -1,665 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..image_processor import PipelineImageInput
|
||||
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from .modular_pipeline_utils import InputParam
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# YiYi Notes: this is actually for SDXL, put it here for now
|
||||
SDXL_INPUTS_SCHEMA = {
|
||||
"prompt": InputParam(
|
||||
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
|
||||
),
|
||||
"prompt_2": InputParam(
|
||||
"prompt_2",
|
||||
type_hint=Union[str, List[str]],
|
||||
description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
|
||||
),
|
||||
"negative_prompt": InputParam(
|
||||
"negative_prompt",
|
||||
type_hint=Union[str, List[str]],
|
||||
description="The prompt or prompts not to guide the image generation",
|
||||
),
|
||||
"negative_prompt_2": InputParam(
|
||||
"negative_prompt_2",
|
||||
type_hint=Union[str, List[str]],
|
||||
description="The negative prompt or prompts for text_encoder_2",
|
||||
),
|
||||
"cross_attention_kwargs": InputParam(
|
||||
"cross_attention_kwargs",
|
||||
type_hint=Optional[dict],
|
||||
description="Kwargs dictionary passed to the AttentionProcessor",
|
||||
),
|
||||
"clip_skip": InputParam(
|
||||
"clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
|
||||
),
|
||||
"image": InputParam(
|
||||
"image",
|
||||
type_hint=PipelineImageInput,
|
||||
required=True,
|
||||
description="The image(s) to modify for img2img or inpainting",
|
||||
),
|
||||
"mask_image": InputParam(
|
||||
"mask_image",
|
||||
type_hint=PipelineImageInput,
|
||||
required=True,
|
||||
description="Mask image for inpainting, white pixels will be repainted",
|
||||
),
|
||||
"generator": InputParam(
|
||||
"generator",
|
||||
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
|
||||
description="Generator(s) for deterministic generation",
|
||||
),
|
||||
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
|
||||
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
|
||||
"num_images_per_prompt": InputParam(
|
||||
"num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
|
||||
),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
|
||||
),
|
||||
"timesteps": InputParam(
|
||||
"timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
|
||||
),
|
||||
"sigmas": InputParam(
|
||||
"sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
|
||||
),
|
||||
"denoising_end": InputParam(
|
||||
"denoising_end",
|
||||
type_hint=Optional[float],
|
||||
description="Fraction of denoising process to complete before termination",
|
||||
),
|
||||
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
|
||||
"strength": InputParam(
|
||||
"strength", type_hint=float, default=0.3, description="How much to transform the reference image"
|
||||
),
|
||||
"denoising_start": InputParam(
|
||||
"denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
|
||||
),
|
||||
"padding_mask_crop": InputParam(
|
||||
"padding_mask_crop",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Size of margin in crop for image and mask",
|
||||
),
|
||||
"original_size": InputParam(
|
||||
"original_size",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Original size of the image for SDXL's micro-conditioning",
|
||||
),
|
||||
"target_size": InputParam(
|
||||
"target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
|
||||
),
|
||||
"negative_original_size": InputParam(
|
||||
"negative_original_size",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Negative conditioning based on image resolution",
|
||||
),
|
||||
"negative_target_size": InputParam(
|
||||
"negative_target_size",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Negative conditioning based on target resolution",
|
||||
),
|
||||
"crops_coords_top_left": InputParam(
|
||||
"crops_coords_top_left",
|
||||
type_hint=Tuple[int, int],
|
||||
default=(0, 0),
|
||||
description="Top-left coordinates for SDXL's micro-conditioning",
|
||||
),
|
||||
"negative_crops_coords_top_left": InputParam(
|
||||
"negative_crops_coords_top_left",
|
||||
type_hint=Tuple[int, int],
|
||||
default=(0, 0),
|
||||
description="Negative conditioning crop coordinates",
|
||||
),
|
||||
"aesthetic_score": InputParam(
|
||||
"aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
|
||||
),
|
||||
"negative_aesthetic_score": InputParam(
|
||||
"negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
|
||||
),
|
||||
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
|
||||
"output_type": InputParam(
|
||||
"output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
|
||||
),
|
||||
"ip_adapter_image": InputParam(
|
||||
"ip_adapter_image",
|
||||
type_hint=PipelineImageInput,
|
||||
required=True,
|
||||
description="Image(s) to be used as IP adapter",
|
||||
),
|
||||
"control_image": InputParam(
|
||||
"control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
|
||||
),
|
||||
"control_guidance_start": InputParam(
|
||||
"control_guidance_start",
|
||||
type_hint=Union[float, List[float]],
|
||||
default=0.0,
|
||||
description="When ControlNet starts applying",
|
||||
),
|
||||
"control_guidance_end": InputParam(
|
||||
"control_guidance_end",
|
||||
type_hint=Union[float, List[float]],
|
||||
default=1.0,
|
||||
description="When ControlNet stops applying",
|
||||
),
|
||||
"controlnet_conditioning_scale": InputParam(
|
||||
"controlnet_conditioning_scale",
|
||||
type_hint=Union[float, List[float]],
|
||||
default=1.0,
|
||||
description="Scale factor for ControlNet outputs",
|
||||
),
|
||||
"guess_mode": InputParam(
|
||||
"guess_mode",
|
||||
type_hint=bool,
|
||||
default=False,
|
||||
description="Enables ControlNet encoder to recognize input without prompts",
|
||||
),
|
||||
"control_mode": InputParam(
|
||||
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
|
||||
),
|
||||
}
|
||||
|
||||
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"prompt_embeds": InputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
required=True,
|
||||
description="Text embeddings used to guide image generation",
|
||||
),
|
||||
"negative_prompt_embeds": InputParam(
|
||||
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
|
||||
),
|
||||
"pooled_prompt_embeds": InputParam(
|
||||
"pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
|
||||
),
|
||||
"negative_pooled_prompt_embeds": InputParam(
|
||||
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
|
||||
),
|
||||
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
|
||||
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
"preprocess_kwargs": InputParam(
|
||||
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
|
||||
),
|
||||
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
|
||||
),
|
||||
"latent_timestep": InputParam(
|
||||
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
|
||||
),
|
||||
"image_latents": InputParam(
|
||||
"image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
|
||||
),
|
||||
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
|
||||
"masked_image_latents": InputParam(
|
||||
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
|
||||
),
|
||||
"add_time_ids": InputParam(
|
||||
"add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
|
||||
),
|
||||
"negative_add_time_ids": InputParam(
|
||||
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
|
||||
),
|
||||
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
|
||||
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
|
||||
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
|
||||
"ip_adapter_embeds": InputParam(
|
||||
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
|
||||
),
|
||||
"negative_ip_adapter_embeds": InputParam(
|
||||
"negative_ip_adapter_embeds",
|
||||
type_hint=List[torch.Tensor],
|
||||
description="Negative image embeddings for IP-Adapter",
|
||||
),
|
||||
"images": InputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
|
||||
required=True,
|
||||
description="Generated images",
|
||||
),
|
||||
}
|
||||
|
||||
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
|
||||
|
||||
|
||||
DEFAULT_PARAM_MAPS = {
|
||||
"prompt": {
|
||||
"label": "Prompt",
|
||||
"type": "string",
|
||||
"default": "a bear sitting in a chair drinking a milkshake",
|
||||
"display": "textarea",
|
||||
},
|
||||
"negative_prompt": {
|
||||
"label": "Negative Prompt",
|
||||
"type": "string",
|
||||
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
|
||||
"display": "textarea",
|
||||
},
|
||||
"num_inference_steps": {
|
||||
"label": "Steps",
|
||||
"type": "int",
|
||||
"default": 25,
|
||||
"min": 1,
|
||||
"max": 1000,
|
||||
},
|
||||
"seed": {
|
||||
"label": "Seed",
|
||||
"type": "int",
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"display": "random",
|
||||
},
|
||||
"width": {
|
||||
"label": "Width",
|
||||
"type": "int",
|
||||
"display": "text",
|
||||
"default": 1024,
|
||||
"min": 8,
|
||||
"max": 8192,
|
||||
"step": 8,
|
||||
"group": "dimensions",
|
||||
},
|
||||
"height": {
|
||||
"label": "Height",
|
||||
"type": "int",
|
||||
"display": "text",
|
||||
"default": 1024,
|
||||
"min": 8,
|
||||
"max": 8192,
|
||||
"step": 8,
|
||||
"group": "dimensions",
|
||||
},
|
||||
"images": {
|
||||
"label": "Images",
|
||||
"type": "image",
|
||||
"display": "output",
|
||||
},
|
||||
"image": {
|
||||
"label": "Image",
|
||||
"type": "image",
|
||||
"display": "input",
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_TYPE_MAPS = {
|
||||
"int": {
|
||||
"type": "int",
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
},
|
||||
"float": {
|
||||
"type": "float",
|
||||
"default": 0.0,
|
||||
"min": 0.0,
|
||||
},
|
||||
"str": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
"bool": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
},
|
||||
"image": {
|
||||
"type": "image",
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
|
||||
DEFAULT_CATEGORY = "Modular Diffusers"
|
||||
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
|
||||
DEFAULT_PARAMS_GROUPS_KEYS = {
|
||||
"text_encoders": ["text_encoder", "tokenizer"],
|
||||
"ip_adapter_embeds": ["ip_adapter_embeds"],
|
||||
"prompt_embeddings": ["prompt_embeds"],
|
||||
}
|
||||
|
||||
|
||||
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
|
||||
"""
|
||||
Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" ->
|
||||
"text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
|
||||
"""
|
||||
if name is None:
|
||||
return None
|
||||
for group_name, group_keys in group_params_keys.items():
|
||||
for group_key in group_keys:
|
||||
if group_key in name:
|
||||
return group_name
|
||||
return None
|
||||
|
||||
|
||||
class ModularNode(ConfigMixin):
|
||||
"""
|
||||
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
|
||||
around a ModularPipelineBlocks object.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental feature and is likely to change in the future.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
config_name = "node_config.json"
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
trust_remote_code: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
blocks = ModularPipelineBlocks.from_pretrained(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
||||
)
|
||||
return cls(blocks, **kwargs)
|
||||
|
||||
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
|
||||
self.blocks = blocks
|
||||
|
||||
if label is None:
|
||||
label = self.blocks.__class__.__name__
|
||||
# blocks param name -> mellon param name
|
||||
self.name_mapping = {}
|
||||
|
||||
input_params = {}
|
||||
# pass or create a default param dict for each input
|
||||
# e.g. for prompt,
|
||||
# prompt = {
|
||||
# "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
|
||||
# "label": "Prompt",
|
||||
# "type": "string",
|
||||
# "default": "a bear sitting in a chair drinking a milkshake",
|
||||
# "display": "textarea"}
|
||||
# if type is not specified, it'll be a "custom" param of its own type
|
||||
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
|
||||
# it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
|
||||
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
|
||||
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
|
||||
for inp in inputs:
|
||||
param = kwargs.pop(inp.name, None)
|
||||
if param:
|
||||
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
|
||||
input_params[inp.name] = param
|
||||
mellon_name = param.pop("name", inp.name)
|
||||
if mellon_name != inp.name:
|
||||
self.name_mapping[inp.name] = mellon_name
|
||||
continue
|
||||
|
||||
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
|
||||
continue
|
||||
|
||||
if inp.name in DEFAULT_PARAM_MAPS:
|
||||
# first check if it's in the default param map, if so, directly use that
|
||||
param = DEFAULT_PARAM_MAPS[inp.name].copy()
|
||||
elif get_group_name(inp.name):
|
||||
param = get_group_name(inp.name)
|
||||
if inp.name not in self.name_mapping:
|
||||
self.name_mapping[inp.name] = param
|
||||
else:
|
||||
# if not, check if it's in the SDXL input schema, if so,
|
||||
# 1. use the type hint to determine the type
|
||||
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
|
||||
if inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).lower()
|
||||
else:
|
||||
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
|
||||
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
|
||||
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
|
||||
if type_key in type_str:
|
||||
param = type_param.copy()
|
||||
param["label"] = inp.name
|
||||
param["display"] = "input"
|
||||
break
|
||||
else:
|
||||
param = inp.name
|
||||
# add the param dict to the inp_params dict
|
||||
input_params[inp.name] = param
|
||||
|
||||
component_params = {}
|
||||
for comp in self.blocks.expected_components:
|
||||
param = kwargs.pop(comp.name, None)
|
||||
if param:
|
||||
component_params[comp.name] = param
|
||||
mellon_name = param.pop("name", comp.name)
|
||||
if mellon_name != comp.name:
|
||||
self.name_mapping[comp.name] = mellon_name
|
||||
continue
|
||||
|
||||
to_exclude = False
|
||||
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
|
||||
if exclude_key in comp.name:
|
||||
to_exclude = True
|
||||
break
|
||||
if to_exclude:
|
||||
continue
|
||||
|
||||
if get_group_name(comp.name):
|
||||
param = get_group_name(comp.name)
|
||||
if comp.name not in self.name_mapping:
|
||||
self.name_mapping[comp.name] = param
|
||||
elif comp.name in DEFAULT_MODEL_KEYS:
|
||||
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
|
||||
else:
|
||||
param = comp.name
|
||||
# add the param dict to the model_params dict
|
||||
component_params[comp.name] = param
|
||||
|
||||
output_params = {}
|
||||
if isinstance(self.blocks, SequentialPipelineBlocks):
|
||||
last_block_name = list(self.blocks.sub_blocks.keys())[-1]
|
||||
outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
|
||||
else:
|
||||
outputs = self.blocks.intermediate_outputs
|
||||
|
||||
for out in outputs:
|
||||
param = kwargs.pop(out.name, None)
|
||||
if param:
|
||||
output_params[out.name] = param
|
||||
mellon_name = param.pop("name", out.name)
|
||||
if mellon_name != out.name:
|
||||
self.name_mapping[out.name] = mellon_name
|
||||
continue
|
||||
|
||||
if out.name in DEFAULT_PARAM_MAPS:
|
||||
param = DEFAULT_PARAM_MAPS[out.name].copy()
|
||||
param["display"] = "output"
|
||||
else:
|
||||
group_name = get_group_name(out.name)
|
||||
if group_name:
|
||||
param = group_name
|
||||
if out.name not in self.name_mapping:
|
||||
self.name_mapping[out.name] = param
|
||||
else:
|
||||
param = out.name
|
||||
# add the param dict to the outputs dict
|
||||
output_params[out.name] = param
|
||||
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unused kwargs: {kwargs}")
|
||||
|
||||
register_dict = {
|
||||
"category": category,
|
||||
"label": label,
|
||||
"input_params": input_params,
|
||||
"component_params": component_params,
|
||||
"output_params": output_params,
|
||||
"name_mapping": self.name_mapping,
|
||||
}
|
||||
self.register_to_config(**register_dict)
|
||||
|
||||
def setup(self, components_manager, collection=None):
|
||||
self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
|
||||
self._components_manager = components_manager
|
||||
|
||||
@property
|
||||
def mellon_config(self):
|
||||
return self._convert_to_mellon_config()
|
||||
|
||||
def _convert_to_mellon_config(self):
|
||||
node = {}
|
||||
node["label"] = self.config.label
|
||||
node["category"] = self.config.category
|
||||
|
||||
node_param = {}
|
||||
for inp_name, inp_param in self.config.input_params.items():
|
||||
if inp_name in self.name_mapping:
|
||||
mellon_name = self.name_mapping[inp_name]
|
||||
else:
|
||||
mellon_name = inp_name
|
||||
if isinstance(inp_param, str):
|
||||
param = {
|
||||
"label": inp_param,
|
||||
"type": inp_param,
|
||||
"display": "input",
|
||||
}
|
||||
else:
|
||||
param = inp_param
|
||||
|
||||
if mellon_name not in node_param:
|
||||
node_param[mellon_name] = param
|
||||
else:
|
||||
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
|
||||
|
||||
for comp_name, comp_param in self.config.component_params.items():
|
||||
if comp_name in self.name_mapping:
|
||||
mellon_name = self.name_mapping[comp_name]
|
||||
else:
|
||||
mellon_name = comp_name
|
||||
if isinstance(comp_param, str):
|
||||
param = {
|
||||
"label": comp_param,
|
||||
"type": comp_param,
|
||||
"display": "input",
|
||||
}
|
||||
else:
|
||||
param = comp_param
|
||||
|
||||
if mellon_name not in node_param:
|
||||
node_param[mellon_name] = param
|
||||
else:
|
||||
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
|
||||
|
||||
for out_name, out_param in self.config.output_params.items():
|
||||
if out_name in self.name_mapping:
|
||||
mellon_name = self.name_mapping[out_name]
|
||||
else:
|
||||
mellon_name = out_name
|
||||
if isinstance(out_param, str):
|
||||
param = {
|
||||
"label": out_param,
|
||||
"type": out_param,
|
||||
"display": "output",
|
||||
}
|
||||
else:
|
||||
param = out_param
|
||||
|
||||
if mellon_name not in node_param:
|
||||
node_param[mellon_name] = param
|
||||
else:
|
||||
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
|
||||
node["params"] = node_param
|
||||
return node
|
||||
|
||||
def save_mellon_config(self, file_path):
|
||||
"""
|
||||
Save the Mellon configuration to a JSON file.
|
||||
|
||||
Args:
|
||||
file_path (str or Path): Path where the JSON file will be saved
|
||||
|
||||
Returns:
|
||||
Path: Path to the saved config file
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(file_path.parent, exist_ok=True)
|
||||
|
||||
# Create a combined dictionary with module definition and name mapping
|
||||
config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
|
||||
|
||||
# Save the config to file
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
logger.info(f"Mellon config and name mapping saved to {file_path}")
|
||||
|
||||
return file_path
|
||||
|
||||
@classmethod
|
||||
def load_mellon_config(cls, file_path):
|
||||
"""
|
||||
Load a Mellon configuration from a JSON file.
|
||||
|
||||
Args:
|
||||
file_path (str or Path): Path to the JSON file containing Mellon config
|
||||
|
||||
Returns:
|
||||
dict: The loaded combined configuration containing 'module' and 'name_mapping'
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {file_path}")
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
logger.info(f"Mellon config loaded from {file_path}")
|
||||
|
||||
return config
|
||||
|
||||
def process_inputs(self, **kwargs):
|
||||
params_components = {}
|
||||
for comp_name, comp_param in self.config.component_params.items():
|
||||
logger.debug(f"component: {comp_name}")
|
||||
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
|
||||
if mellon_comp_name in kwargs:
|
||||
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
|
||||
comp = kwargs[mellon_comp_name].pop(comp_name)
|
||||
else:
|
||||
comp = kwargs.pop(mellon_comp_name)
|
||||
if comp:
|
||||
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
|
||||
|
||||
params_run = {}
|
||||
for inp_name, inp_param in self.config.input_params.items():
|
||||
logger.debug(f"input: {inp_name}")
|
||||
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
|
||||
if mellon_inp_name in kwargs:
|
||||
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
|
||||
inp = kwargs[mellon_inp_name].pop(inp_name)
|
||||
else:
|
||||
inp = kwargs.pop(mellon_inp_name)
|
||||
if inp is not None:
|
||||
params_run[inp_name] = inp
|
||||
|
||||
return_output_names = list(self.config.output_params.keys())
|
||||
|
||||
return params_components, params_run, return_output_names
|
||||
|
||||
def execute(self, **kwargs):
|
||||
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
|
||||
|
||||
self.pipeline.update_components(**params_components)
|
||||
output = self.pipeline(**params_run, output=return_output_names)
|
||||
return output
|
||||
@@ -577,9 +577,8 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(
|
||||
name="resized_image", required=True, type_hint=torch.Tensor, description="The resized image input"
|
||||
),
|
||||
InputParam(name="image_height", required=True),
|
||||
InputParam(name="image_width", required=True),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
@@ -612,10 +611,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# for edit, image size can be different from the target size (height/width)
|
||||
image = (
|
||||
block_state.resized_image[0] if isinstance(block_state.resized_image, list) else block_state.resized_image
|
||||
)
|
||||
image_width, image_height = image.size
|
||||
|
||||
block_state.img_shapes = [
|
||||
[
|
||||
@@ -624,7 +619,11 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
block_state.height // components.vae_scale_factor // 2,
|
||||
block_state.width // components.vae_scale_factor // 2,
|
||||
),
|
||||
(1, image_height // components.vae_scale_factor // 2, image_width // components.vae_scale_factor // 2),
|
||||
(
|
||||
1,
|
||||
block_state.image_height // components.vae_scale_factor // 2,
|
||||
block_state.image_width // components.vae_scale_factor // 2,
|
||||
),
|
||||
]
|
||||
] * block_state.batch_size
|
||||
|
||||
|
||||
@@ -496,7 +496,7 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
|
||||
)
|
||||
|
||||
if components.requires_unconditional_embeds:
|
||||
negative_prompt = block_state.negative_prompt or ""
|
||||
negative_prompt = block_state.negative_prompt or " "
|
||||
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
|
||||
components.text_encoder,
|
||||
components.processor,
|
||||
|
||||
@@ -307,6 +307,13 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
|
||||
OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
@@ -327,6 +334,11 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
if not hasattr(block_state, "image_height"):
|
||||
block_state.image_height = height
|
||||
if not hasattr(block_state, "image_width"):
|
||||
block_state.image_width = width
|
||||
|
||||
# 2. Patchify the image latent tensor
|
||||
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
|
||||
|
||||
|
||||
@@ -511,17 +511,42 @@ class QwenImageAutoDecodeStep(AutoPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageAutoInputStep,
|
||||
QwenImageOptionalControlNetInputStep,
|
||||
QwenImageAutoBeforeDenoiseStep,
|
||||
QwenImageOptionalControlNetBeforeDenoiseStep,
|
||||
QwenImageAutoDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process. \n"
|
||||
+ " - `QwenImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
|
||||
+ " - `QwenImageOptionalControlNetInputStep` (controlnet_input) prepares the controlnet input.\n"
|
||||
+ " - `QwenImageAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
|
||||
+ " - `QwenImageOptionalControlNetBeforeDenoiseStep` (controlnet_before_denoise) prepares the controlnet input for the denoising step.\n"
|
||||
+ " - `QwenImageAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
|
||||
+ " - `QwenImageAutoDecodeStep` (decode) decodes the latents into images.\n\n"
|
||||
+ "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
|
||||
+ " - for image-to-image generation, you need to provide `image_latents`\n"
|
||||
+ " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
|
||||
+ " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
|
||||
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
|
||||
)
|
||||
|
||||
|
||||
## 1.10 QwenImage/auto block & presets
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageAutoVaeEncoderStep()),
|
||||
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
|
||||
("input", QwenImageAutoInputStep()),
|
||||
("controlnet_input", QwenImageOptionalControlNetInputStep()),
|
||||
("before_denoise", QwenImageAutoBeforeDenoiseStep()),
|
||||
("controlnet_before_denoise", QwenImageOptionalControlNetBeforeDenoiseStep()),
|
||||
("denoise", QwenImageAutoDenoiseStep()),
|
||||
("denoise", QwenImageCoreDenoiseStep()),
|
||||
("decode", QwenImageAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -699,7 +724,7 @@ class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
class QwenImageEditAutoInputStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -800,13 +825,34 @@ class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks):
|
||||
|
||||
## 2.7 QwenImage-Edit/auto blocks & presets
|
||||
|
||||
|
||||
class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditAutoInputStep,
|
||||
QwenImageEditAutoBeforeDenoiseStep,
|
||||
QwenImageEditAutoDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process. \n"
|
||||
+ " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
|
||||
+ " - `QwenImageEditAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
|
||||
+ " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
|
||||
+ "This step support edit (img2img) and edit inpainting workflow for QwenImage Edit:\n"
|
||||
+ " - When `processed_mask_image` is provided, it will be used for edit inpainting task.\n"
|
||||
+ " - When `image_latents` is provided, it will be used for edit (img2img) task.\n"
|
||||
)
|
||||
|
||||
|
||||
EDIT_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
|
||||
("input", QwenImageEditAutoInputStep()),
|
||||
("before_denoise", QwenImageEditAutoBeforeDenoiseStep()),
|
||||
("denoise", QwenImageEditAutoDenoiseStep()),
|
||||
("denoise", QwenImageEditCoreDenoiseStep()),
|
||||
("decode", QwenImageAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -104,6 +104,8 @@ class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
default_blocks_name = "QwenImageAutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
@@ -158,6 +160,8 @@ class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
default_blocks_name = "QwenImageEditAutoBlocks"
|
||||
|
||||
# YiYi TODO: qwen edit should not provide default height/width, should be derived from the resized input image (after adjustment) produced by the resize step.
|
||||
@property
|
||||
def default_height(self):
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# mellon nodes
|
||||
QwenImage_NODE_TYPES_PARAMS_MAP = {
|
||||
"controlnet": {
|
||||
"inputs": [
|
||||
"control_image",
|
||||
"controlnet_conditioning_scale",
|
||||
"control_guidance_start",
|
||||
"control_guidance_end",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
"model_inputs": [
|
||||
"controlnet",
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"controlnet_out",
|
||||
],
|
||||
"block_names": ["controlnet_vae_encoder"],
|
||||
},
|
||||
"denoise": {
|
||||
"inputs": [
|
||||
"embeddings",
|
||||
"width",
|
||||
"height",
|
||||
"seed",
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"image_latents",
|
||||
"strength",
|
||||
"controlnet",
|
||||
],
|
||||
"model_inputs": [
|
||||
"unet",
|
||||
"guider",
|
||||
"scheduler",
|
||||
],
|
||||
"outputs": [
|
||||
"latents",
|
||||
"latents_preview",
|
||||
],
|
||||
"block_names": ["denoise"],
|
||||
},
|
||||
"vae_encoder": {
|
||||
"inputs": [
|
||||
"image",
|
||||
"width",
|
||||
"height",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"image_latents",
|
||||
],
|
||||
},
|
||||
"text_encoder": {
|
||||
"inputs": [
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
],
|
||||
"model_inputs": [
|
||||
"text_encoders",
|
||||
],
|
||||
"outputs": [
|
||||
"embeddings",
|
||||
],
|
||||
},
|
||||
"decoder": {
|
||||
"inputs": [
|
||||
"latents",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"images",
|
||||
],
|
||||
},
|
||||
}
|
||||
@@ -262,37 +262,37 @@ class StableDiffusionXLInputStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="negative text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"pooled_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="pooled text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_pooled_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="negative pooled text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"ip_adapter_embeds",
|
||||
type_hint=List[torch.Tensor],
|
||||
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="image embeddings for IP-Adapter",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_ip_adapter_embeds",
|
||||
type_hint=List[torch.Tensor],
|
||||
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="negative image embeddings for IP-Adapter",
|
||||
),
|
||||
]
|
||||
@@ -1120,13 +1120,13 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineB
|
||||
OutputParam(
|
||||
"add_time_ids",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The time ids to condition the denoising process",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_add_time_ids",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The negative time ids to condition the denoising process",
|
||||
),
|
||||
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
|
||||
@@ -1331,13 +1331,13 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
"add_time_ids",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The time ids to condition the denoising process",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_add_time_ids",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The negative time ids to condition the denoising process",
|
||||
),
|
||||
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
|
||||
|
||||
@@ -183,14 +183,14 @@ class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||
description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description=(
|
||||
"All conditional model inputs that need to be prepared with guider. "
|
||||
"It should contain prompt_embeds/negative_prompt_embeds, "
|
||||
"add_time_ids/negative_add_time_ids, "
|
||||
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
|
||||
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
|
||||
"please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
|
||||
"please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -307,14 +307,14 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description=(
|
||||
"All conditional model inputs that need to be prepared with guider. "
|
||||
"It should contain prompt_embeds/negative_prompt_embeds, "
|
||||
"add_time_ids/negative_add_time_ids, "
|
||||
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
|
||||
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
|
||||
"please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
|
||||
"please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
|
||||
),
|
||||
),
|
||||
InputParam(
|
||||
|
||||
@@ -258,25 +258,25 @@ class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="negative text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"pooled_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="pooled text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_pooled_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="negative pooled text embeddings used to guide the image generation",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -82,19 +82,17 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
|
||||
# before_denoise: text2img
|
||||
class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
StableDiffusionXLInputStep,
|
||||
StableDiffusionXLSetTimestepsStep,
|
||||
StableDiffusionXLPrepareLatentsStep,
|
||||
StableDiffusionXLPrepareAdditionalConditioningStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
|
||||
block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
|
||||
@@ -104,19 +102,17 @@ class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
# before_denoise: img2img
|
||||
class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
StableDiffusionXLInputStep,
|
||||
StableDiffusionXLImg2ImgSetTimestepsStep,
|
||||
StableDiffusionXLImg2ImgPrepareLatentsStep,
|
||||
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
|
||||
block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
|
||||
@@ -126,19 +122,17 @@ class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
# before_denoise: inpainting
|
||||
class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
StableDiffusionXLInputStep,
|
||||
StableDiffusionXLImg2ImgSetTimestepsStep,
|
||||
StableDiffusionXLInpaintPrepareLatentsStep,
|
||||
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
|
||||
block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step for inpainting task.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
|
||||
@@ -255,25 +249,48 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
class StableDiffusionXLCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
StableDiffusionXLInputStep,
|
||||
StableDiffusionXLAutoBeforeDenoiseStep,
|
||||
StableDiffusionXLAutoControlNetInputStep,
|
||||
StableDiffusionXLAutoDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "before_denoise", "controlnet_input", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process. \n"
|
||||
+ " - `StableDiffusionXLInputStep` (input) standardizes the inputs for the denoising step.\n"
|
||||
+ " - `StableDiffusionXLAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
|
||||
+ " - `StableDiffusionXLAutoControlNetInputStep` (controlnet_input) prepares the controlnet input.\n"
|
||||
+ " - `StableDiffusionXLAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
|
||||
+ "This step support text-to-image, image-to-image, inpainting, with or without controlnet/controlnet_union/ip_adapter for Stable Diffusion XL:\n"
|
||||
+ "- for image-to-image generation, you need to provide `image_latents`\n"
|
||||
+ "- for inpainting, you need to provide `mask_image` and `image_latents`\n"
|
||||
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
|
||||
+ "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
|
||||
+ "- to run the ip_adapter workflow, you need to load ip_adapter into your unet and provide `ip_adapter_embeds`\n"
|
||||
+ "- for text-to-image generation, all you need to provide is prompt embeddings\n"
|
||||
)
|
||||
|
||||
|
||||
# ip-adapter, controlnet, text2img, img2img, inpainting
|
||||
class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
StableDiffusionXLTextEncoderStep,
|
||||
StableDiffusionXLAutoIPAdapterStep,
|
||||
StableDiffusionXLAutoVaeEncoderStep,
|
||||
StableDiffusionXLAutoBeforeDenoiseStep,
|
||||
StableDiffusionXLAutoControlNetInputStep,
|
||||
StableDiffusionXLAutoDenoiseStep,
|
||||
StableDiffusionXLCoreDenoiseStep,
|
||||
StableDiffusionXLAutoDecodeStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"ip_adapter",
|
||||
"image_encoder",
|
||||
"before_denoise",
|
||||
"controlnet_input",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decoder",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -321,7 +338,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("image_encoder", StableDiffusionXLVaeEncoderStep),
|
||||
("vae_encoder", StableDiffusionXLVaeEncoderStep),
|
||||
("input", StableDiffusionXLInputStep),
|
||||
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
|
||||
@@ -334,7 +351,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
INPAINT_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
|
||||
("vae_encoder", StableDiffusionXLInpaintVaeEncoderStep),
|
||||
("input", StableDiffusionXLInputStep),
|
||||
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
|
||||
@@ -361,10 +378,8 @@ AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
|
||||
("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
|
||||
("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
|
||||
("controlnet_input", StableDiffusionXLAutoControlNetInputStep),
|
||||
("denoise", StableDiffusionXLAutoDenoiseStep),
|
||||
("vae_encoder", StableDiffusionXLAutoVaeEncoderStep),
|
||||
("denoise", StableDiffusionXLCoreDenoiseStep),
|
||||
("decode", StableDiffusionXLAutoDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -54,6 +54,8 @@ class StableDiffusionXLModularPipeline(
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
default_blocks_name = "StableDiffusionXLAutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
SDXL_NODE_TYPES_PARAMS_MAP = {
|
||||
"controlnet": {
|
||||
"inputs": [
|
||||
"control_image",
|
||||
"controlnet_conditioning_scale",
|
||||
"control_guidance_start",
|
||||
"control_guidance_end",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
"model_inputs": [
|
||||
"controlnet",
|
||||
],
|
||||
"outputs": [
|
||||
"controlnet_out",
|
||||
],
|
||||
"block_names": [None],
|
||||
},
|
||||
"denoise": {
|
||||
"inputs": [
|
||||
"embeddings",
|
||||
"width",
|
||||
"height",
|
||||
"seed",
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"image_latents",
|
||||
"strength",
|
||||
# custom adapters coming in as inputs
|
||||
"controlnet",
|
||||
# ip_adapter is optional and custom; include if available
|
||||
"ip_adapter",
|
||||
],
|
||||
"model_inputs": [
|
||||
"unet",
|
||||
"guider",
|
||||
"scheduler",
|
||||
],
|
||||
"outputs": [
|
||||
"latents",
|
||||
"latents_preview",
|
||||
],
|
||||
"block_names": ["denoise"],
|
||||
},
|
||||
"vae_encoder": {
|
||||
"inputs": [
|
||||
"image",
|
||||
"width",
|
||||
"height",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"image_latents",
|
||||
],
|
||||
"block_names": ["vae_encoder"],
|
||||
},
|
||||
"text_encoder": {
|
||||
"inputs": [
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
],
|
||||
"model_inputs": [
|
||||
"text_encoders",
|
||||
],
|
||||
"outputs": [
|
||||
"embeddings",
|
||||
],
|
||||
"block_names": ["text_encoder"],
|
||||
},
|
||||
"decoder": {
|
||||
"inputs": [
|
||||
"latents",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"images",
|
||||
],
|
||||
"block_names": ["decode"],
|
||||
},
|
||||
}
|
||||
@@ -146,13 +146,13 @@ class WanInputStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="negative text embeddings used to guide the image generation",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -79,11 +79,11 @@ class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description=(
|
||||
"All conditional model inputs that need to be prepared with guider. "
|
||||
"It should contain prompt_embeds/negative_prompt_embeds. "
|
||||
"Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
|
||||
"Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@@ -89,13 +89,13 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="guider_input_fields",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="negative text embeddings used to guide the image generation",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -37,6 +37,8 @@ class WanModularPipeline(
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
default_blocks_name = "WanAutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_height * self.vae_scale_factor_spatial
|
||||
|
||||
@@ -285,6 +285,7 @@ else:
|
||||
]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
[
|
||||
"MarigoldDepthPipeline",
|
||||
@@ -393,6 +394,7 @@ else:
|
||||
"QwenImageImg2ImgPipeline",
|
||||
"QwenImageInpaintPipeline",
|
||||
"QwenImageEditPipeline",
|
||||
"QwenImageEditPlusPipeline",
|
||||
"QwenImageEditInpaintPipeline",
|
||||
"QwenImageControlNetInpaintPipeline",
|
||||
"QwenImageControlNetPipeline",
|
||||
@@ -682,6 +684,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
|
||||
from .lucy import LucyEditPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
from .marigold import (
|
||||
@@ -719,6 +722,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
QwenImageEditPipeline,
|
||||
QwenImageEditPlusPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImagePipeline,
|
||||
|
||||
@@ -238,7 +238,7 @@ class ChromaPipeline(
|
||||
# Chroma requires the attention mask to include one padding token
|
||||
seq_lengths = attention_mask.sum(dim=1)
|
||||
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
|
||||
@@ -246,7 +246,7 @@ class ChromaPipeline(
|
||||
|
||||
dtype = self.text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
attention_mask = attention_mask.to(dtype=dtype, device=device)
|
||||
attention_mask = attention_mask.to(device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
@@ -605,10 +605,9 @@ class ChromaPipeline(
|
||||
|
||||
# Extend the prompt attention mask to account for image tokens in the final sequence
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
|
||||
[attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
|
||||
dim=1,
|
||||
)
|
||||
attention_mask = attention_mask.to(dtype)
|
||||
|
||||
return attention_mask
|
||||
|
||||
@@ -688,11 +687,11 @@ class ChromaPipeline(
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||
|
||||
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
|
||||
@@ -749,12 +749,12 @@ class ChromaImg2ImgPipeline(
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||
|
||||
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
strength (`float, *optional*, defaults to 0.9):
|
||||
Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
|
||||
be used as a starting point, adding more noise to it the larger the strength. The number of denoising
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_lucy_edit"] = ["LucyEditPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_lucy_edit import LucyEditPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,735 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 The Decart AI Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modifications by Decart AI Team:
|
||||
# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.
|
||||
|
||||
import html
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...models import AutoencoderKLWan, WanTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import LucyPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> from typing import List
|
||||
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> from diffusers import AutoencoderKLWan, LucyEditPipeline
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> # Arguments
|
||||
>>> url = "https://d2drjpuinn46lb.cloudfront.net/painter_original_edit.mp4"
|
||||
>>> prompt = "Change the apron and blouse to a classic clown costume: satin polka-dot jumpsuit in bright primary colors, ruffled white collar, oversized pom-pom buttons, white gloves, oversized red shoes, red foam nose; soft window light from left, eye-level medium shot, natural folds and fabric highlights."
|
||||
>>> negative_prompt = ""
|
||||
>>> num_frames = 81
|
||||
>>> height = 480
|
||||
>>> width = 832
|
||||
|
||||
|
||||
>>> # Load video
|
||||
>>> def convert_video(video: List[Image.Image]) -> List[Image.Image]:
|
||||
... video = load_video(url)[:num_frames]
|
||||
... video = [video[i].resize((width, height)) for i in range(num_frames)]
|
||||
... return video
|
||||
|
||||
|
||||
>>> video = load_video(url, convert_method=convert_video)
|
||||
|
||||
>>> # Load model
|
||||
>>> model_id = "decart-ai/Lucy-Edit-Dev"
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = LucyEditPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Generate video
|
||||
>>> output = pipe(
|
||||
... prompt=prompt,
|
||||
... video=video,
|
||||
... negative_prompt=negative_prompt,
|
||||
... height=480,
|
||||
... width=832,
|
||||
... num_frames=81,
|
||||
... guidance_scale=5.0,
|
||||
... ).frames[0]
|
||||
|
||||
>>> # Export video
|
||||
>>> export_to_video(output, "output.mp4", fps=24)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LucyEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for video-to-video generation using Lucy Edit.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
transformer_2 ([`WanTransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
|
||||
two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
|
||||
stages. If not provided, only `transformer` is used.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
||||
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
||||
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer", "transformer_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer: Optional[WanTransformer3DModel] = None,
|
||||
transformer_2: Optional[WanTransformer3DModel] = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
expand_timesteps: bool = False, # Wan2.2 ti2v
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
transformer_2=transformer_2,
|
||||
)
|
||||
self.register_to_config(boundary_ratio=boundary_ratio)
|
||||
self.register_to_config(expand_timesteps=expand_timesteps)
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
video,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
||||
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
||||
|
||||
if video is None:
|
||||
raise ValueError("`video` is required, received None.")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
num_latent_frames = (
|
||||
(video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
|
||||
)
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_latent_frames,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
# Prepare noise latents
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# Prepare condition latents
|
||||
condition_latents = [
|
||||
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video
|
||||
]
|
||||
|
||||
condition_latents = torch.cat(condition_latents, dim=0).to(dtype)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
device, dtype
|
||||
)
|
||||
|
||||
condition_latents = (condition_latents - latents_mean) * latents_std
|
||||
|
||||
# Check shapes
|
||||
assert latents.shape == condition_latents.shape, (
|
||||
f"Latents shape {latents.shape} does not match expected shape {condition_latents.shape}. Please check the input."
|
||||
)
|
||||
|
||||
return latents, condition_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
video: List[Image.Image],
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
video (`List[Image.Image]`):
|
||||
The video to use as the condition for the video generation.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (`guidance_scale` < `1`).
|
||||
height (`int`, defaults to `480`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `832`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `81`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
guidance_scale_2 (`float`, *optional*, defaults to `None`):
|
||||
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
|
||||
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
|
||||
and the pipeline's `boundary_ratio` are not None.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`LucyPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
||||
truncated. If the prompt is shorter, it will be padded to this length.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~LucyPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`LucyPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
video,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
guidance_scale_2,
|
||||
)
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
logger.warning(
|
||||
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
|
||||
)
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
|
||||
guidance_scale_2 = guidance_scale
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_scale_2 = guidance_scale_2
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = (
|
||||
self.transformer.config.out_channels
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.out_channels
|
||||
)
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width).to(
|
||||
device, dtype=torch.float32
|
||||
)
|
||||
latents, condition_latents = self.prepare_latents(
|
||||
video,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
if self.config.boundary_ratio is not None:
|
||||
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
else:
|
||||
boundary_timestep = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
if boundary_timestep is None or t >= boundary_timestep:
|
||||
# wan2.1 or high-noise stage in wan2.2
|
||||
current_model = self.transformer
|
||||
current_guidance_scale = guidance_scale
|
||||
else:
|
||||
# low-noise stage in wan2.2
|
||||
current_model = self.transformer_2
|
||||
current_guidance_scale = guidance_scale_2
|
||||
|
||||
# latent_model_input = latents.to(transformer_dtype)
|
||||
latent_model_input = torch.cat([latents, condition_latents], dim=1).to(transformer_dtype)
|
||||
# latent_model_input = torch.cat([latents, latents], dim=1).to(transformer_dtype)
|
||||
if self.config.expand_timesteps:
|
||||
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
||||
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
|
||||
# batch_size, seq_len
|
||||
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
|
||||
else:
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
with current_model.cache_context("cond"):
|
||||
noise_pred = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with current_model.cache_context("uncond"):
|
||||
noise_uncond = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LucyPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class LucyPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Lucy pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
@@ -19,12 +19,12 @@ import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download
|
||||
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
|
||||
from huggingface_hub.utils import HfHubHTTPError, OfflineModeIsEnabled, validate_hf_hub_args
|
||||
from packaging import version
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from .. import __version__
|
||||
from ..utils import (
|
||||
@@ -48,10 +48,12 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import dispatch_model
|
||||
@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
@@ -191,7 +195,9 @@ def filter_model_files(filenames):
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
|
||||
|
||||
@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
@@ -1102,7 +1110,7 @@ def _download_dduf_file(
|
||||
if not local_files_only:
|
||||
try:
|
||||
info = model_info(pretrained_model_name, token=token, revision=revision)
|
||||
except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
|
||||
except (HfHubHTTPError, OfflineModeIsEnabled, requests.ConnectionError, httpx.NetworkError) as e:
|
||||
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
|
||||
local_files_only = True
|
||||
model_info_call_error = e # save error to reraise it if model is not cached locally
|
||||
|
||||
@@ -23,6 +23,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import requests
|
||||
@@ -36,9 +37,8 @@ from huggingface_hub import (
|
||||
read_dduf_file,
|
||||
snapshot_download,
|
||||
)
|
||||
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
|
||||
from huggingface_hub.utils import HfHubHTTPError, OfflineModeIsEnabled, validate_hf_hub_args
|
||||
from packaging import version
|
||||
from requests.exceptions import HTTPError
|
||||
from tqdm.auto import tqdm
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -1616,7 +1616,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
if not local_files_only:
|
||||
try:
|
||||
info = model_info(pretrained_model_name, token=token, revision=revision)
|
||||
except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
|
||||
except (HfHubHTTPError, OfflineModeIsEnabled, requests.ConnectionError, httpx.NetworkError) as e:
|
||||
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
|
||||
local_files_only = True
|
||||
model_info_call_error = e # save error to reraise it if model is not cached locally
|
||||
|
||||
@@ -28,6 +28,7 @@ else:
|
||||
_import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
|
||||
_import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
|
||||
_import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
|
||||
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
|
||||
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
|
||||
|
||||
@@ -43,6 +44,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_qwenimage_controlnet_inpaint import QwenImageControlNetInpaintPipeline
|
||||
from .pipeline_qwenimage_edit import QwenImageEditPipeline
|
||||
from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline
|
||||
from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
|
||||
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
|
||||
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
|
||||
else:
|
||||
|
||||
@@ -208,7 +208,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.vl_processor = processor
|
||||
self.tokenizer_max_length = 1024
|
||||
|
||||
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
@@ -0,0 +1,883 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import QwenImageLoraLoaderMixin
|
||||
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import QwenImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import QwenImageEditPlusPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
|
||||
... ).convert("RGB")
|
||||
>>> prompt = (
|
||||
... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
|
||||
... )
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(image, prompt, num_inference_steps=50).images[0]
|
||||
>>> image.save("qwenimage_edit_plus.png")
|
||||
```
|
||||
"""
|
||||
|
||||
CONDITION_IMAGE_SIZE = 384 * 384
|
||||
VAE_IMAGE_SIZE = 1024 * 1024
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def calculate_dimensions(target_area, ratio):
|
||||
width = math.sqrt(target_area * ratio)
|
||||
height = width / ratio
|
||||
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
|
||||
return width, height
|
||||
|
||||
|
||||
class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
r"""
|
||||
The Qwen-Image-Edit pipeline for image editing.
|
||||
|
||||
Args:
|
||||
transformer ([`QwenImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
|
||||
tokenizer (`QwenTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLQwenImage,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
processor: Qwen2VLProcessor,
|
||||
transformer: QwenImageTransformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
|
||||
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.tokenizer_max_length = 1024
|
||||
|
||||
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.prompt_template_encode_start_idx = 64
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
|
||||
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
|
||||
return split_result
|
||||
|
||||
def _get_qwen_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
|
||||
if isinstance(image, list):
|
||||
base_img_prompt = ""
|
||||
for i, img in enumerate(image):
|
||||
base_img_prompt += img_prompt_template.format(i + 1)
|
||||
elif image is not None:
|
||||
base_img_prompt = img_prompt_template.format(1)
|
||||
else:
|
||||
base_img_prompt = ""
|
||||
|
||||
template = self.prompt_template_encode
|
||||
|
||||
drop_idx = self.prompt_template_encode_start_idx
|
||||
txt = [template.format(base_img_prompt + e) for e in prompt]
|
||||
|
||||
model_inputs = self.processor(
|
||||
text=txt,
|
||||
images=image,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
|
||||
outputs = self.text_encoder(
|
||||
input_ids=model_inputs.input_ids,
|
||||
attention_mask=model_inputs.attention_mask,
|
||||
pixel_values=model_inputs.pixel_values,
|
||||
image_grid_thw=model_inputs.image_grid_thw,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
hidden_states = outputs.hidden_states[-1]
|
||||
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
||||
)
|
||||
encoder_attention_mask = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 1024,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
image (`torch.Tensor`, *optional*):
|
||||
image to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.latent_channels, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, self.latent_channels, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
image_latents = (image_latents - latents_mean) / latents_std
|
||||
|
||||
return image_latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
images,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, 1, num_channels_latents, height, width)
|
||||
|
||||
image_latents = None
|
||||
if images is not None:
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
all_image_latents = []
|
||||
for image in images:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != self.latent_channels:
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
image_latents = torch.cat([image_latents], dim=0)
|
||||
|
||||
image_latent_height, image_latent_width = image_latents.shape[3:]
|
||||
image_latents = self._pack_latents(
|
||||
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
|
||||
)
|
||||
all_image_latents.append(image_latents)
|
||||
image_latents = torch.cat(all_image_latents, dim=1)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
return latents, image_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: Optional[PipelineImageInput] = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
true_cfg_scale: float = 4.0,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
||||
not greater than `1`).
|
||||
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
||||
true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
|
||||
Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
|
||||
equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
|
||||
enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
|
||||
encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
|
||||
lower image quality.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to None):
|
||||
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
|
||||
where the guidance scale is applied during inference through noise prediction rescaling, guidance
|
||||
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
|
||||
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
|
||||
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
|
||||
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
|
||||
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
|
||||
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
|
||||
enable classifier-free guidance computations).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
image_size = image[-1].size if isinstance(image, list) else image.size
|
||||
calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
|
||||
height = height or calculated_height
|
||||
width = width or calculated_width
|
||||
|
||||
multiple_of = self.vae_scale_factor * 2
|
||||
width = width // multiple_of * multiple_of
|
||||
height = height // multiple_of * multiple_of
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# 3. Preprocess image
|
||||
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
||||
if not isinstance(image, list):
|
||||
image = [image]
|
||||
condition_image_sizes = []
|
||||
condition_images = []
|
||||
vae_image_sizes = []
|
||||
vae_images = []
|
||||
for img in image:
|
||||
image_width, image_height = img.size
|
||||
condition_width, condition_height = calculate_dimensions(
|
||||
CONDITION_IMAGE_SIZE, image_width / image_height
|
||||
)
|
||||
vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
|
||||
condition_image_sizes.append((condition_width, condition_height))
|
||||
vae_image_sizes.append((vae_width, vae_height))
|
||||
condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
|
||||
vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
|
||||
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
||||
)
|
||||
|
||||
if true_cfg_scale > 1 and not has_neg_prompt:
|
||||
logger.warning(
|
||||
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
|
||||
)
|
||||
elif true_cfg_scale <= 1 and has_neg_prompt:
|
||||
logger.warning(
|
||||
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
|
||||
)
|
||||
|
||||
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
||||
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
||||
image=condition_images,
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_true_cfg:
|
||||
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
||||
image=condition_images,
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, image_latents = self.prepare_latents(
|
||||
vae_images,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
img_shapes = [
|
||||
[
|
||||
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
|
||||
*[
|
||||
(1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
|
||||
for vae_width, vae_height in vae_image_sizes
|
||||
],
|
||||
]
|
||||
] * batch_size
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance
|
||||
if self.transformer.config.guidance_embeds and guidance_scale is None:
|
||||
raise ValueError("guidance_scale is required for guidance-distilled model.")
|
||||
elif self.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
|
||||
logger.warning(
|
||||
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
|
||||
)
|
||||
guidance = None
|
||||
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
|
||||
guidance = None
|
||||
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
latent_model_input = latents
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred[:, : latents.size(1)]
|
||||
|
||||
if do_true_cfg:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
|
||||
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
||||
noise_pred = comb_pred * (cond_norm / noise_norm)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return QwenImagePipelineOutput(images=image)
|
||||
@@ -795,7 +795,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# Simplification of implementation for now
|
||||
if not isinstance(prompt, str):
|
||||
if prompt is not None and not isinstance(prompt, str):
|
||||
raise ValueError("Passing a list of prompts is not yet supported. This may be supported in the future.")
|
||||
if num_videos_per_prompt != 1:
|
||||
raise ValueError(
|
||||
|
||||
@@ -21,19 +21,20 @@ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e
|
||||
"""
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import is_torch_available, is_torchao_available, logging
|
||||
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -443,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
"""This is a config class for torchao quantization/sparsity techniques.
|
||||
|
||||
Args:
|
||||
quant_type (`str`):
|
||||
quant_type (Union[`str`, AOBaseConfig]):
|
||||
The type of quantization we want to use, currently supporting:
|
||||
- **Integer quantization:**
|
||||
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
|
||||
@@ -465,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
- **Unsigned Integer quantization:**
|
||||
- Full function names: `uintx_weight_only`
|
||||
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
|
||||
- An AOBaseConfig instance: for more advanced configuration options.
|
||||
modules_to_not_convert (`List[str]`, *optional*, default to `None`):
|
||||
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
|
||||
modules left in their original precision.
|
||||
@@ -478,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
```python
|
||||
from diffusers import FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
# AOBaseConfig-based configuration
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
|
||||
# String-based config
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
@@ -490,7 +498,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_type: str,
|
||||
quant_type: Union[str, "AOBaseConfig"], # noqa: F821
|
||||
modules_to_not_convert: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@@ -504,34 +512,103 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
else:
|
||||
self.quant_type_kwargs = kwargs
|
||||
|
||||
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
|
||||
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
|
||||
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
|
||||
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
if not isinstance(self.quant_type, str):
|
||||
if is_torchao_version("<=", "0.9.0"):
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
|
||||
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
|
||||
f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
|
||||
f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
|
||||
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
from torchao.quantization.quant_api import AOBaseConfig
|
||||
|
||||
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
|
||||
signature = inspect.signature(method)
|
||||
all_kwargs = {
|
||||
param.name
|
||||
for param in signature.parameters.values()
|
||||
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
|
||||
}
|
||||
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
|
||||
if not isinstance(self.quant_type, AOBaseConfig):
|
||||
raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}")
|
||||
|
||||
if len(unsupported_kwargs) > 0:
|
||||
raise ValueError(
|
||||
f'The quantization method "{quant_type}" does not support the following keyword arguments: '
|
||||
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
|
||||
)
|
||||
elif isinstance(self.quant_type, str):
|
||||
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
|
||||
|
||||
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
|
||||
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
|
||||
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
|
||||
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
|
||||
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
|
||||
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
|
||||
signature = inspect.signature(method)
|
||||
all_kwargs = {
|
||||
param.name
|
||||
for param in signature.parameters.values()
|
||||
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
|
||||
}
|
||||
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
|
||||
|
||||
if len(unsupported_kwargs) > 0:
|
||||
raise ValueError(
|
||||
f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
|
||||
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert configuration to a dictionary."""
|
||||
d = super().to_dict()
|
||||
|
||||
if isinstance(self.quant_type, str):
|
||||
# Handle layout serialization if present
|
||||
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
|
||||
if is_dataclass(d["quant_type_kwargs"]["layout"]):
|
||||
d["quant_type_kwargs"]["layout"] = [
|
||||
d["quant_type_kwargs"]["layout"].__class__.__name__,
|
||||
dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
|
||||
]
|
||||
if isinstance(d["quant_type_kwargs"]["layout"], list):
|
||||
assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
|
||||
assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
|
||||
assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
|
||||
else:
|
||||
raise ValueError("layout must be a list")
|
||||
else:
|
||||
# Handle AOBaseConfig serialization
|
||||
from torchao.core.config import config_to_dict
|
||||
|
||||
# For now we assume there is 1 config per Transformer, however in the future
|
||||
# We may want to support a config per fqn.
|
||||
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
|
||||
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
|
||||
"""Create configuration from a dictionary."""
|
||||
if not is_torchao_version(">", "0.9.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
|
||||
config_dict = config_dict.copy()
|
||||
quant_type = config_dict.pop("quant_type")
|
||||
|
||||
if isinstance(quant_type, str):
|
||||
return cls(quant_type=quant_type, **config_dict)
|
||||
# Check if we only have one key which is "default"
|
||||
# In the future we may update this
|
||||
assert len(quant_type) == 1 and "default" in quant_type, (
|
||||
"Expected only one key 'default' in quant_type dictionary"
|
||||
)
|
||||
quant_type = quant_type["default"]
|
||||
|
||||
# Deserialize quant_type if needed
|
||||
from torchao.core.config import config_from_dict
|
||||
|
||||
quant_type = config_from_dict(quant_type)
|
||||
|
||||
return cls(quant_type=quant_type, **config_dict)
|
||||
|
||||
@classmethod
|
||||
def _get_torchao_quant_type_to_method(cls):
|
||||
@@ -681,8 +758,38 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
|
||||
|
||||
def get_apply_tensor_subclass(self):
|
||||
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
|
||||
return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
|
||||
"""Create the appropriate quantization method based on configuration."""
|
||||
if not isinstance(self.quant_type, str):
|
||||
return self.quant_type
|
||||
else:
|
||||
methods = self._get_torchao_quant_type_to_method()
|
||||
quant_type_kwargs = self.quant_type_kwargs.copy()
|
||||
if (
|
||||
not torch.cuda.is_available()
|
||||
and is_torchao_available()
|
||||
and self.quant_type == "int4_weight_only"
|
||||
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
|
||||
and quant_type_kwargs.get("layout", None) is None
|
||||
):
|
||||
if torch.xpu.is_available():
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
|
||||
"0.11.0"
|
||||
) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
|
||||
from torchao.dtypes import Int4XPULayout
|
||||
from torchao.quantization.quant_primitives import ZeroPointDomain
|
||||
|
||||
quant_type_kwargs["layout"] = Int4XPULayout()
|
||||
quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
|
||||
else:
|
||||
raise ValueError(
|
||||
"TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
|
||||
)
|
||||
else:
|
||||
from torchao.dtypes import Int4CPULayout
|
||||
|
||||
quant_type_kwargs["layout"] = Int4CPULayout()
|
||||
|
||||
return methods[self.quant_type](**quant_type_kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
r"""
|
||||
|
||||
@@ -18,9 +18,10 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import re
|
||||
import types
|
||||
from fnmatch import fnmatch
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@@ -107,6 +108,21 @@ if (
|
||||
_update_torch_safe_globals()
|
||||
|
||||
|
||||
def fuzzy_match_size(config_name: str) -> Optional[str]:
|
||||
"""
|
||||
Extract the size digit from strings like "4weight", "8weight". Returns the digit as an integer if found, otherwise
|
||||
None.
|
||||
"""
|
||||
config_name = config_name.lower()
|
||||
|
||||
str_match = re.search(r"(\d)weight", config_name)
|
||||
|
||||
if str_match:
|
||||
return str_match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -176,8 +192,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
|
||||
def update_torch_dtype(self, torch_dtype):
|
||||
quant_type = self.quantization_config.quant_type
|
||||
|
||||
if quant_type.startswith("int") or quant_type.startswith("uint"):
|
||||
if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
|
||||
if torch_dtype is not None and torch_dtype != torch.bfloat16:
|
||||
logger.warning(
|
||||
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
|
||||
@@ -197,24 +212,44 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
|
||||
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
||||
quant_type = self.quantization_config.quant_type
|
||||
from accelerate.utils import CustomDtype
|
||||
|
||||
if quant_type.startswith("int8") or quant_type.startswith("int4"):
|
||||
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
|
||||
return torch.int8
|
||||
elif quant_type == "uintx_weight_only":
|
||||
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
|
||||
elif quant_type.startswith("uint"):
|
||||
return {
|
||||
1: torch.uint1,
|
||||
2: torch.uint2,
|
||||
3: torch.uint3,
|
||||
4: torch.uint4,
|
||||
5: torch.uint5,
|
||||
6: torch.uint6,
|
||||
7: torch.uint7,
|
||||
}[int(quant_type[4])]
|
||||
elif quant_type.startswith("float") or quant_type.startswith("fp"):
|
||||
return torch.bfloat16
|
||||
if isinstance(quant_type, str):
|
||||
if quant_type.startswith("int8"):
|
||||
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
|
||||
return torch.int8
|
||||
elif quant_type.startswith("int4"):
|
||||
return CustomDtype.INT4
|
||||
elif quant_type == "uintx_weight_only":
|
||||
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
|
||||
elif quant_type.startswith("uint"):
|
||||
return {
|
||||
1: torch.uint1,
|
||||
2: torch.uint2,
|
||||
3: torch.uint3,
|
||||
4: torch.uint4,
|
||||
5: torch.uint5,
|
||||
6: torch.uint6,
|
||||
7: torch.uint7,
|
||||
}[int(quant_type[4])]
|
||||
elif quant_type.startswith("float") or quant_type.startswith("fp"):
|
||||
return torch.bfloat16
|
||||
|
||||
elif is_torchao_version(">", "0.9.0"):
|
||||
from torchao.core.config import AOBaseConfig
|
||||
|
||||
quant_type = self.quantization_config.quant_type
|
||||
if isinstance(quant_type, AOBaseConfig):
|
||||
# Extract size digit using fuzzy match on the class name
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
# Map the extracted digit to appropriate dtype
|
||||
if size_digit == "4":
|
||||
return CustomDtype.INT4
|
||||
else:
|
||||
# Default to int8
|
||||
return torch.int8
|
||||
|
||||
if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
|
||||
return target_dtype
|
||||
@@ -297,6 +332,21 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
# Original mapping for non-AOBaseConfig types
|
||||
# For the uint types, this is a best guess. Once these types become more used
|
||||
# we can look into their nuances.
|
||||
if is_torchao_version(">", "0.9.0"):
|
||||
from torchao.core.config import AOBaseConfig
|
||||
|
||||
quant_type = self.quantization_config.quant_type
|
||||
# For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
|
||||
if isinstance(quant_type, AOBaseConfig):
|
||||
# Extract size digit using fuzzy match on the class name
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
if size_digit == "4":
|
||||
return 8
|
||||
else:
|
||||
return 4
|
||||
|
||||
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
|
||||
quant_type = self.quantization_config.quant_type
|
||||
for pattern, target_dtype in map_to_target_dtype.items():
|
||||
|
||||
@@ -648,6 +648,21 @@ class ConsistencyDecoderVAE(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ContextParallelConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1053,6 +1068,21 @@ class OmniGenTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ParallelConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PixArtTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1592,6 +1592,21 @@ class LTXPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LucyEditPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Lumina2Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1877,6 +1892,21 @@ class QwenImageEditPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class QwenImageEditPlusPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class QwenImageImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -247,6 +247,7 @@ def find_pipeline_class(loaded_module):
|
||||
def get_cached_module_file(
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
module_file: str,
|
||||
subfolder: Optional[str] = None,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
@@ -353,6 +354,7 @@ def get_cached_module_file(
|
||||
resolved_module_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
module_file,
|
||||
subfolder=subfolder,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
@@ -410,6 +412,7 @@ def get_cached_module_file(
|
||||
get_cached_module_file(
|
||||
pretrained_model_name_or_path,
|
||||
f"{module_needed}.py",
|
||||
subfolder=subfolder,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
@@ -424,6 +427,7 @@ def get_cached_module_file(
|
||||
def get_class_from_dynamic_module(
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
module_file: str,
|
||||
subfolder: Optional[str] = None,
|
||||
class_name: Optional[str] = None,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
@@ -497,6 +501,7 @@ def get_class_from_dynamic_module(
|
||||
final_module = get_cached_module_file(
|
||||
pretrained_model_name_or_path,
|
||||
module_file,
|
||||
subfolder=subfolder,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
|
||||
@@ -38,13 +38,13 @@ from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
|
||||
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
is_jinja_available,
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
from packaging import version
|
||||
from requests import HTTPError
|
||||
|
||||
from .. import __version__
|
||||
from .constants import (
|
||||
@@ -316,7 +316,7 @@ def _get_model_file(
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
|
||||
) from e
|
||||
except HTTPError as e:
|
||||
except HfHubHTTPError as e:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{e}"
|
||||
) from e
|
||||
@@ -432,7 +432,7 @@ def _get_checkpoint_shard_files(
|
||||
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
|
||||
except HTTPError as e:
|
||||
except HfHubHTTPError as e:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
|
||||
" again after checking your internet connection."
|
||||
|
||||
@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = AuraFlowPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user