Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 92199ff3ac | |||
| 04e9323055 | |||
| 9a09162baf | |||
| 33a8a3be0c | |||
| 58743c3ee7 | |||
| 50c0b786d2 |
@@ -110,9 +110,8 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -116,9 +116,8 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -254,10 +253,9 @@ jobs:
|
||||
python -m uv pip install -e [quality,test]
|
||||
# TODO (sayakpaul, DN6): revisit `--no-deps`
|
||||
python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
# python -m uv pip install -U tokenizers
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -U tokenizers
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -132,9 +132,8 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -204,9 +203,8 @@ jobs:
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -268,8 +266,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
# Stopping this update temporarily until the Hub RC is fully shipped and integrated.
|
||||
# pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -23,7 +23,11 @@
|
||||
- local: using-diffusers/reusing_seeds
|
||||
title: Reproducibility
|
||||
- local: using-diffusers/schedulers
|
||||
title: Schedulers
|
||||
title: Load schedulers and models
|
||||
- local: using-diffusers/models
|
||||
title: Models
|
||||
- local: using-diffusers/scheduler_features
|
||||
title: Scheduler features
|
||||
- local: using-diffusers/other-formats
|
||||
title: Model files and layouts
|
||||
- local: using-diffusers/push_to_hub
|
||||
@@ -64,14 +68,10 @@
|
||||
title: Accelerate inference
|
||||
- local: optimization/cache
|
||||
title: Caching
|
||||
- local: optimization/attention_backends
|
||||
title: Attention backends
|
||||
- local: optimization/memory
|
||||
title: Reduce memory usage
|
||||
- local: optimization/speed-memory-optims
|
||||
title: Compiling and offloading quantized models
|
||||
- local: api/parallel
|
||||
title: Parallel inference
|
||||
- title: Community optimizations
|
||||
sections:
|
||||
- local: optimization/pruna
|
||||
@@ -82,8 +82,6 @@
|
||||
title: Token merging
|
||||
- local: optimization/deepcache
|
||||
title: DeepCache
|
||||
- local: optimization/cache_dit
|
||||
title: CacheDiT
|
||||
- local: optimization/tgate
|
||||
title: TGATE
|
||||
- local: optimization/xdit
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Parallelism
|
||||
|
||||
Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times.
|
||||
|
||||
## ParallelConfig
|
||||
|
||||
[[autodoc]] ParallelConfig
|
||||
|
||||
## ContextParallelConfig
|
||||
|
||||
[[autodoc]] ContextParallelConfig
|
||||
|
||||
[[autodoc]] hooks.apply_context_parallel
|
||||
@@ -26,7 +26,6 @@ Qwen-Image comes in the following variants:
|
||||
|:----------:|:--------:|
|
||||
| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
|
||||
| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |
|
||||
| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -97,29 +96,6 @@ The `guidance_scale` parameter in the pipeline is there to support future guidan
|
||||
|
||||
</Tip>
|
||||
|
||||
## Multi-image reference with QwenImageEditPlusPipeline
|
||||
|
||||
With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
|
||||
|
||||
```
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import QwenImageEditPlusPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = QwenImageEditPlusPipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
|
||||
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
|
||||
image = pipe(
|
||||
image=[image_1, image_2],
|
||||
prompt="put the penguin and the cat at a game show called "Qwen Edit Plus Games"",
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## QwenImagePipeline
|
||||
|
||||
[[autodoc]] QwenImagePipeline
|
||||
@@ -150,15 +126,7 @@ image = pipe(
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImageControlNetPipeline
|
||||
|
||||
[[autodoc]] QwenImageControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImageEditPlusPipeline
|
||||
|
||||
[[autodoc]] QwenImageEditPlusPipeline
|
||||
## QwenImaggeControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Attention backends
|
||||
|
||||
> [!NOTE]
|
||||
> The attention dispatcher is an experimental feature. Please open an issue if you have any feedback or encounter any problems.
|
||||
|
||||
Diffusers provides several optimized attention algorithms that are more memory and computationally efficient through it's *attention dispatcher*. The dispatcher acts as a router for managing and switching between different attention implementations and provides a unified interface for interacting with them.
|
||||
|
||||
Refer to the table below for an overview of the available attention families and to the [Available backends](#available-backends) section for a more complete list.
|
||||
|
||||
| attention family | main feature |
|
||||
|---|---|
|
||||
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
|
||||
| SageAttention | quantizes attention to int8 |
|
||||
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
|
||||
| xFormers | memory-efficient attention with support for various attention kernels |
|
||||
|
||||
This guide will show you how to set and use the different attention backends.
|
||||
|
||||
## set_attention_backend
|
||||
|
||||
The [`~ModelMixin.set_attention_backend`] method iterates through all the modules in the model and sets the appropriate attention backend to use. The attention backend setting persists until [`~ModelMixin.reset_attention_backend`] is called.
|
||||
|
||||
The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [kernel](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
|
||||
|
||||
> [!NOTE]
|
||||
> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend("flash")`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImagePipeline
|
||||
|
||||
pipeline = QwenImagePipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
pipeline.transformer.set_attention_backend("_flash_3_hub")
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
pipeline(prompt).images[0]
|
||||
```
|
||||
|
||||
To restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].
|
||||
|
||||
```py
|
||||
pipeline.transformer.reset_attention_backend()
|
||||
```
|
||||
|
||||
## attention_backend context manager
|
||||
|
||||
The [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager temporarily sets an attention backend for a model within the context. Outside the context, the default attention (PyTorch's native scaled dot product attention) is used. This is useful if you want to use different backends for different parts of a pipeline or if you want to test the different backends.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImagePipeline
|
||||
|
||||
pipeline = QwenImagePipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
|
||||
with attention_backend("_flash_3_hub"):
|
||||
image = pipeline(prompt).images[0]
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Most attention backends support `torch.compile` without graph breaks and can be used to further speed up inference.
|
||||
|
||||
## Available backends
|
||||
|
||||
Refer to the table below for a complete list of available attention backends and their variants.
|
||||
|
||||
<details>
|
||||
<summary>Expand</summary>
|
||||
|
||||
| Backend Name | Family | Description |
|
||||
|--------------|--------|-------------|
|
||||
| `native` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Default backend using PyTorch's scaled_dot_product_attention |
|
||||
| `flex` | [FlexAttention](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention) | PyTorch FlexAttention implementation |
|
||||
| `_native_cudnn` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | CuDNN-optimized attention |
|
||||
| `_native_efficient` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Memory-efficient attention |
|
||||
| `_native_flash` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | PyTorch's FlashAttention |
|
||||
| `_native_math` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Math-based attention (fallback) |
|
||||
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
|
||||
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
|
||||
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
|
||||
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
|
||||
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
|
||||
| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
|
||||
| `_sage_qk_int8_pv_fp16_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (CUDA) |
|
||||
| `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) |
|
||||
| `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention |
|
||||
|
||||
</details>
|
||||
@@ -1,270 +0,0 @@
|
||||
## CacheDiT
|
||||
|
||||
CacheDiT is a unified, flexible, and training-free cache acceleration framework designed to support nearly all Diffusers' DiT-based pipelines. It provides a unified cache API that supports automatic block adapter, DBCache, and more.
|
||||
|
||||
To learn more, refer to the [CacheDiT](https://github.com/vipshop/cache-dit) repository.
|
||||
|
||||
Install a stable release of CacheDiT from PyPI or you can install the latest version from GitHub.
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="PyPI">
|
||||
|
||||
```bash
|
||||
pip3 install -U cache-dit
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="source">
|
||||
|
||||
```bash
|
||||
pip3 install git+https://github.com/vipshop/cache-dit.git
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Run the command below to view supported DiT pipelines.
|
||||
|
||||
```python
|
||||
>>> import cache_dit
|
||||
>>> cache_dit.supported_pipelines()
|
||||
(30, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',
|
||||
'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',
|
||||
'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',
|
||||
'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*'])
|
||||
```
|
||||
|
||||
For a complete benchmark, please refer to [Benchmarks](https://github.com/vipshop/cache-dit/blob/main/bench/).
|
||||
|
||||
|
||||
## Unified Cache API
|
||||
|
||||
CacheDiT works by matching specific input/output patterns as shown below.
|
||||
|
||||

|
||||
|
||||
Call the `enable_cache()` function on a pipeline to enable cache acceleration. This function is the entry point to many of CacheDiT's features.
|
||||
|
||||
```python
|
||||
import cache_dit
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
# Can be any diffusion pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
|
||||
|
||||
# One-line code with default cache options.
|
||||
cache_dit.enable_cache(pipe)
|
||||
|
||||
# Just call the pipe as normal.
|
||||
output = pipe(...)
|
||||
|
||||
# Disable cache and run original pipe.
|
||||
cache_dit.disable_cache(pipe)
|
||||
```
|
||||
|
||||
## Automatic Block Adapter
|
||||
|
||||
For custom or modified pipelines or transformers not included in Diffusers, use the `BlockAdapter` in `auto` mode or via manual configuration. Please check the [BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#automatic-block-adapter) docs for more details. Refer to [Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.
|
||||
|
||||
|
||||
```python
|
||||
from cache_dit import ForwardPattern, BlockAdapter
|
||||
|
||||
# Use 🔥BlockAdapter with `auto` mode.
|
||||
cache_dit.enable_cache(
|
||||
BlockAdapter(
|
||||
# Any DiffusionPipeline, Qwen-Image, etc.
|
||||
pipe=pipe, auto=True,
|
||||
# Check `📚Forward Pattern Matching` documentation and hack the code of
|
||||
# of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
|
||||
forward_pattern=ForwardPattern.Pattern_1,
|
||||
),
|
||||
)
|
||||
|
||||
# Or, manually setup transformer configurations.
|
||||
cache_dit.enable_cache(
|
||||
BlockAdapter(
|
||||
pipe=pipe, # Qwen-Image, etc.
|
||||
transformer=pipe.transformer,
|
||||
blocks=pipe.transformer.transformer_blocks,
|
||||
forward_pattern=ForwardPattern.Pattern_1,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
Sometimes, a Transformer class will contain more than one transformer `blocks`. For example, FLUX.1 (HiDream, Chroma, etc) contains `transformer_blocks` and `single_transformer_blocks` (with different forward patterns). The BlockAdapter is able to detect this hybrid pattern type as well.
|
||||
Refer to [FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.
|
||||
|
||||
```python
|
||||
# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
|
||||
# single_transformer_blocks have different forward patterns.
|
||||
cache_dit.enable_cache(
|
||||
BlockAdapter(
|
||||
pipe=pipe, # FLUX.1, etc.
|
||||
transformer=pipe.transformer,
|
||||
blocks=[
|
||||
pipe.transformer.transformer_blocks,
|
||||
pipe.transformer.single_transformer_blocks,
|
||||
],
|
||||
forward_pattern=[
|
||||
ForwardPattern.Pattern_1,
|
||||
ForwardPattern.Pattern_3,
|
||||
],
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
This also works if there is more than one transformer (namely `transformer` and `transformer_2`) in its structure. Refer to [Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.
|
||||
|
||||
## Patch Functor
|
||||
|
||||
For any pattern not included in CacheDiT, use the Patch Functor to convert the pattern into a known pattern. You need to subclass the Patch Functor and may also need to fuse the operations within the blocks for loop into block `forward`. After implementing a Patch Functor, set the `patch_functor` property in `BlockAdapter`.
|
||||
|
||||

|
||||
|
||||
Some Patch Functors are already provided in CacheDiT, [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py), [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py), etc.
|
||||
|
||||
```python
|
||||
@BlockAdapterRegistry.register("HiDream")
|
||||
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
|
||||
from diffusers import HiDreamImageTransformer2DModel
|
||||
from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
|
||||
|
||||
assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
|
||||
return BlockAdapter(
|
||||
pipe=pipe,
|
||||
transformer=pipe.transformer,
|
||||
blocks=[
|
||||
pipe.transformer.double_stream_blocks,
|
||||
pipe.transformer.single_stream_blocks,
|
||||
],
|
||||
forward_pattern=[
|
||||
ForwardPattern.Pattern_0,
|
||||
ForwardPattern.Pattern_3,
|
||||
],
|
||||
# NOTE: Setup your custom patch functor here.
|
||||
patch_functor=HiDreamPatchFunctor(),
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
Finally, you can call the `cache_dit.summary()` function on a pipeline after its completed inference to get the cache acceleration details.
|
||||
|
||||
```python
|
||||
stats = cache_dit.summary(pipe)
|
||||
```
|
||||
|
||||
```python
|
||||
⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline
|
||||
|
||||
| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
|
||||
|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
|
||||
| 23 | 0.045 | 0.084 | 0.114 | 0.147 | 0.241 | 0.297 |
|
||||
```
|
||||
|
||||
## DBCache: Dual Block Cache
|
||||
|
||||

|
||||
|
||||
DBCache (Dual Block Caching) supports different configurations of compute blocks (F8B12, etc.) to enable a balanced trade-off between performance and precision.
|
||||
- Fn_compute_blocks: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
|
||||
- Bn_compute_blocks: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
|
||||
|
||||
|
||||
```python
|
||||
import cache_dit
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe_or_adapter = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
# Default options, F8B0, 8 warmup steps, and unlimited cached
|
||||
# steps for good balance between performance and precision
|
||||
cache_dit.enable_cache(pipe_or_adapter)
|
||||
|
||||
# Custom options, F8B8, higher precision
|
||||
from cache_dit import BasicCacheConfig
|
||||
|
||||
cache_dit.enable_cache(
|
||||
pipe_or_adapter,
|
||||
cache_config=BasicCacheConfig(
|
||||
max_warmup_steps=8, # steps do not cache
|
||||
max_cached_steps=-1, # -1 means no limit
|
||||
Fn_compute_blocks=8, # Fn, F8, etc.
|
||||
Bn_compute_blocks=8, # Bn, B8, etc.
|
||||
residual_diff_threshold=0.12,
|
||||
),
|
||||
)
|
||||
```
|
||||
Check the [DBCache](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) and [User Guide](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#dbcache) docs for more design details.
|
||||
|
||||
## TaylorSeer Calibrator
|
||||
|
||||
The [TaylorSeers](https://huggingface.co/papers/2503.06923) algorithm further improves the precision of DBCache in cases where the cached steps are large (Hybrid TaylorSeer + DBCache). At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
|
||||
|
||||
TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. F_pred can be a residual cache or a hidden-state cache.
|
||||
|
||||
```python
|
||||
from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig
|
||||
|
||||
cache_dit.enable_cache(
|
||||
pipe_or_adapter,
|
||||
# Basic DBCache w/ FnBn configurations
|
||||
cache_config=BasicCacheConfig(
|
||||
max_warmup_steps=8, # steps do not cache
|
||||
max_cached_steps=-1, # -1 means no limit
|
||||
Fn_compute_blocks=8, # Fn, F8, etc.
|
||||
Bn_compute_blocks=8, # Bn, B8, etc.
|
||||
residual_diff_threshold=0.12,
|
||||
),
|
||||
# Then, you can use the TaylorSeer Calibrator to approximate
|
||||
# the values in cached steps, taylorseer_order default is 1.
|
||||
calibrator_config=TaylorSeerCalibratorConfig(
|
||||
taylorseer_order=1,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> The `Bn_compute_blocks` parameter of DBCache can be set to `0` if you use TaylorSeer as the calibrator for approximate hidden states. DBCache's `Bn_compute_blocks` also acts as a calibrator, so you can choose either `Bn_compute_blocks` > 0 or TaylorSeer. We recommend using the configuration scheme of TaylorSeer + DBCache FnB0.
|
||||
|
||||
## Hybrid Cache CFG
|
||||
|
||||
CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG in the forward step, please set `enable_separate_cfg` parameter to `False (default, None)`. Otherwise, set it to `True`.
|
||||
|
||||
```python
|
||||
from cache_dit import BasicCacheConfig
|
||||
|
||||
cache_dit.enable_cache(
|
||||
pipe_or_adapter,
|
||||
cache_config=BasicCacheConfig(
|
||||
...,
|
||||
# For example, set it as True for Wan 2.1, Qwen-Image
|
||||
# and set it as False for FLUX.1, HunyuanVideo, etc.
|
||||
enable_separate_cfg=True,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
## torch.compile
|
||||
|
||||
CacheDiT is designed to work with torch.compile for even better performance. Call `torch.compile` after enabling the cache.
|
||||
|
||||
|
||||
```python
|
||||
cache_dit.enable_cache(pipe)
|
||||
|
||||
# Compile the Transformer module
|
||||
pipe.transformer = torch.compile(pipe.transformer)
|
||||
```
|
||||
|
||||
If you're using CacheDiT with dynamic input shapes, consider increasing the `recompile_limit` of `torch._dynamo`. Otherwise, the `recompile_limit` error may be triggered, causing the module to fall back to eager mode.
|
||||
|
||||
```python
|
||||
torch._dynamo.config.recompile_limit = 96 # default is 8
|
||||
torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
|
||||
```
|
||||
|
||||
Please check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.
|
||||
@@ -0,0 +1,120 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Models
|
||||
|
||||
A diffusion model relies on a few individual models working together to generate an output. These models are responsible for denoising, encoding inputs, and decoding latents into the actual outputs.
|
||||
|
||||
This guide will show you how to load models.
|
||||
|
||||
## Loading a model
|
||||
|
||||
All models are loaded with the [`~ModelMixin.from_pretrained`] method, which downloads and caches the latest model version. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache.
|
||||
|
||||
Pass the `subfolder` argument to [`~ModelMixin.from_pretrained`] to specify where to load the model weights from. Omit the `subfolder` argument if the repository doesn't have a subfolder structure or if you're loading a standalone model.
|
||||
|
||||
```py
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer")
|
||||
```
|
||||
|
||||
## AutoModel
|
||||
|
||||
[`AutoModel`] detects the model class from a `model_index.json` file or a model's `config.json` file. It fetches the correct model class from these files and delegates the actual loading to the model class. [`AutoModel`] is useful for automatic model type detection without needing to know the exact model class beforehand.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"Qwen/Qwen-Image", subfolder="transformer"
|
||||
)
|
||||
```
|
||||
|
||||
## Model data types
|
||||
|
||||
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to load a model with a specific data type. This allows you to load a model in a lower precision to reduce memory usage.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained(
|
||||
"Qwen/Qwen-Image",
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
[nn.Module.to](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to) can also convert to a specific data type on the fly. However, it converts *all* weights to the requested data type unlike `torch_dtype` which respects `_keep_in_fp32_modules`. This argument preserves layers in `torch.float32` for numerical stability and best generation quality (see example [_keep_in_fp32_modules](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374))
|
||||
|
||||
```py
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained(
|
||||
"Qwen/Qwen-Image", subfolder="transformer"
|
||||
)
|
||||
model = model.to(dtype=torch.float16)
|
||||
```
|
||||
|
||||
## Device placement
|
||||
|
||||
Use the `device_map` argument in [`~ModelMixin.from_pretrained`] to place a model on an accelerator like a GPU. It is especially helpful where there are multiple GPUs.
|
||||
|
||||
Diffusers currently provides three options to `device_map` for individual models, `"cuda"`, `"balanced"` and `"auto"`. Refer to the table below to compare the three placement strategies.
|
||||
|
||||
| parameter | description |
|
||||
|---|---|
|
||||
| `"cuda"` | places pipeline on a supported accelerator (CUDA) |
|
||||
| `"balanced"` | evenly distributes pipeline on all GPUs |
|
||||
| `"auto"` | distribute model from fastest device first to slowest |
|
||||
|
||||
Use the `max_memory` argument in [`~ModelMixin.from_pretrained`] to allocate a maximum amount of memory to use on each device. By default, Diffusers uses the maximum amount available.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import QwenImagePipeline
|
||||
|
||||
max_memory = {0: "16GB", 1: "16GB"}
|
||||
pipeline = QwenImagePipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda",
|
||||
max_memory=max_memory
|
||||
)
|
||||
```
|
||||
|
||||
The `hf_device_map` attribute allows you to access and view the `device_map`.
|
||||
|
||||
```py
|
||||
print(transformer.hf_device_map)
|
||||
# {'': device(type='cuda')}
|
||||
```
|
||||
|
||||
## Saving models
|
||||
|
||||
Save a model with the [`~ModelMixin.save_pretrained`] method.
|
||||
|
||||
```py
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
model = QwenImageTransformer2DModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer")
|
||||
model.save_pretrained("./local/model")
|
||||
```
|
||||
|
||||
For large models, it is helpful to use `max_shard_size` to save a model as multiple shards. A shard can be loaded faster and save memory (refer to the [parallel loading](./loading#parallel-loading) docs for more details), especially if there is more than one GPU.
|
||||
|
||||
```py
|
||||
model.save_pretrained("./local/model", max_shard_size="5GB")
|
||||
```
|
||||
@@ -0,0 +1,235 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Scheduler features
|
||||
|
||||
The scheduler is an important component of any diffusion model because it controls the entire denoising (or sampling) process. There are many types of schedulers, some are optimized for speed and some for quality. With Diffusers, you can modify the scheduler configuration to use custom noise schedules, sigmas, and rescale the noise schedule. Changing these parameters can have profound effects on inference quality and speed.
|
||||
|
||||
This guide will demonstrate how to use these features to improve inference quality.
|
||||
|
||||
> [!TIP]
|
||||
> Diffusers currently only supports the `timesteps` and `sigmas` parameters for a select list of schedulers and pipelines. Feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
|
||||
|
||||
## Timestep schedules
|
||||
|
||||
The timestep or noise schedule determines the amount of noise at each sampling step. The scheduler uses this to generate an image with the corresponding amount of noise at each step. The timestep schedule is generated from the scheduler's default configuration, but you can customize the scheduler to use new and optimized sampling schedules that aren't in Diffusers yet.
|
||||
|
||||
For example, [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) is a method for optimizing a sampling schedule to generate a high-quality image in as little as 10 steps. The optimal [10-step schedule](https://github.com/huggingface/diffusers/blob/a7bf77fc284810483f1e60afe34d1d27ad91ce2e/src/diffusers/schedulers/scheduling_utils.py#L51) for Stable Diffusion XL is:
|
||||
|
||||
```py
|
||||
from diffusers.schedulers import AysSchedules
|
||||
|
||||
sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
|
||||
print(sampling_schedule)
|
||||
"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
|
||||
```
|
||||
|
||||
You can use the AYS sampling schedule in a pipeline by passing it to the `timesteps` parameter.
|
||||
|
||||
```py
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++")
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
generator = torch.Generator(device="cpu").manual_seed(2487854446)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
generator=generator,
|
||||
timesteps=sampling_schedule,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Timestep spacing
|
||||
|
||||
The way sample steps are selected in the schedule can affect the quality of the generated image, especially with respect to [rescaling the noise schedule](#rescale-noise-schedule), which can enable a model to generate much brighter or darker images. Diffusers provides three timestep spacing methods:
|
||||
|
||||
- `leading` creates evenly spaced steps
|
||||
- `linspace` includes the first and last steps and evenly selects the remaining intermediate steps
|
||||
- `trailing` only includes the last step and evenly selects the remaining intermediate steps starting from the end
|
||||
|
||||
It is recommended to use the `trailing` spacing method because it generates higher quality images with more details when there are fewer sample steps. But the difference in quality is not as obvious for more standard sample step values.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
|
||||
|
||||
prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
|
||||
generator = torch.Generator(device="cpu").manual_seed(2487854446)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
generator=generator,
|
||||
num_inference_steps=5,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/trailing_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">trailing spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/leading_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">leading spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Sigmas
|
||||
|
||||
The `sigmas` parameter is the amount of noise added at each timestep according to the timestep schedule. Like the `timesteps` parameter, you can customize the `sigmas` parameter to control how much noise is added at each step. When you use a custom `sigmas` value, the `timesteps` are calculated from the custom `sigmas` value and the default scheduler configuration is ignored.
|
||||
|
||||
For example, you can manually pass the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) for something like the 10-step AYS schedule from before to the pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
|
||||
prompt = "anthropomorphic capybara wearing a suit and working with a computer"
|
||||
generator = torch.Generator(device='cuda').manual_seed(123)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=10,
|
||||
sigmas=sigmas,
|
||||
generator=generator
|
||||
).images[0]
|
||||
```
|
||||
|
||||
When you take a look at the scheduler's `timesteps` parameter, you'll see that it is the same as the AYS timestep schedule because the `timestep` schedule is calculated from the `sigmas`.
|
||||
|
||||
```py
|
||||
print(f" timesteps: {pipe.scheduler.timesteps}")
|
||||
"timesteps: tensor([999., 845., 730., 587., 443., 310., 193., 116., 53., 13.], device='cuda:0')"
|
||||
```
|
||||
|
||||
### Karras sigmas
|
||||
|
||||
> [!TIP]
|
||||
> Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas.
|
||||
>
|
||||
> Karras sigmas should not be used for models that weren't trained with them. For example, the base Stable Diffusion XL model shouldn't use Karras sigmas but the [DreamShaperXL](https://hf.co/Lykon/dreamshaper-xl-1-0) model can since they are trained with Karras sigmas.
|
||||
|
||||
Karras scheduler's use the timestep schedule and sigmas from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://hf.co/papers/2206.00364) paper. This scheduler variant applies a smaller amount of noise per step as it approaches the end of the sampling process compared to other schedulers, and can increase the level of details in the generated image.
|
||||
|
||||
Enable Karras sigmas by setting `use_karras_sigmas=True` in the scheduler.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
generator = torch.Generator(device="cpu").manual_seed(2487854446)
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
generator=generator,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_true.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas enabled</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_false.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas disabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Rescale noise schedule
|
||||
|
||||
In the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://hf.co/papers/2305.08891) paper, the authors discovered that common noise schedules allowed some signal to leak into the last timestep. This signal leakage at inference can cause models to only generate images with medium brightness. By enforcing a zero signal-to-noise ratio (SNR) for the timstep schedule and sampling from the last timestep, the model can be improved to generate very bright or dark images.
|
||||
|
||||
> [!TIP]
|
||||
> For inference, you need a model that has been trained with *v_prediction*. To train your own model with *v_prediction*, add the following flag to the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) scripts.
|
||||
>
|
||||
> ```bash
|
||||
> --prediction_type="v_prediction"
|
||||
> ```
|
||||
|
||||
For example, load the [ptx0/pseudo-journey-v2](https://hf.co/ptx0/pseudo-journey-v2) checkpoint which was trained with `v_prediction` and the [`DDIMScheduler`]. Configure the following parameters in the [`DDIMScheduler`]:
|
||||
|
||||
* `rescale_betas_zero_snr=True` to rescale the noise schedule to zero SNR
|
||||
* `timestep_spacing="trailing"` to start sampling from the last timestep
|
||||
|
||||
Set `guidance_rescale` in the pipeline to prevent over-exposure. A lower value increases brightness but some of the details may appear washed out.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, DDIMScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", use_safetensors=True)
|
||||
|
||||
pipeline.scheduler = DDIMScheduler.from_config(
|
||||
pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
prompt = "cinematic photo of a snowy mountain at night with the northern lights aurora borealis overhead, 35mm photograph, film, professional, 4k, highly detailed"
|
||||
generator = torch.Generator(device="cpu").manual_seed(23)
|
||||
image = pipeline(prompt, guidance_rescale=0.7, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/no-zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">default Stable Diffusion v2-1 image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">image with zero SNR and trailing timestep spacing enabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
@@ -10,273 +10,200 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Load schedulers and models
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Schedulers
|
||||
Diffusion pipelines are a collection of interchangeable schedulers and models that can be mixed and matched to tailor a pipeline to a specific use case. The scheduler encapsulates the entire denoising process such as the number of denoising steps and the algorithm for finding the denoised sample. A scheduler is not parameterized or trained so they don't take very much memory. The model is usually only concerned with the forward pass of going from a noisy input to a less noisy sample.
|
||||
|
||||
A scheduler is an algorithm that provides instructions to the denoising process such as how much noise to remove at a certain step. It takes the model prediction from step *t* and applies an update for how to compute the next sample at step *t-1*. Different schedulers produce different results; some are faster while others are more accurate.
|
||||
|
||||
Diffusers supports many schedulers and allows you to modify their timestep schedules, timestep spacing, and more, to generate high-quality images in fewer steps.
|
||||
|
||||
This guide will show you how to load and customize schedulers.
|
||||
|
||||
## Loading schedulers
|
||||
|
||||
Schedulers don't have any parameters and are defined in a configuration file. Access the `.scheduler` attribute of a pipeline to view the configuration.
|
||||
This guide will show you how to load schedulers and models to customize a pipeline. You'll use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint throughout this guide, so let's load it first.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Load a different scheduler with [`~SchedulerMixin.from_pretrained`] and specify the `subfolder` argument to load the configuration file into the correct subfolder of the pipeline repository. Pass the new scheduler to the existing pipeline.
|
||||
You can see what scheduler this pipeline uses with the `pipeline.scheduler` attribute.
|
||||
|
||||
```py
|
||||
pipeline.scheduler
|
||||
PNDMScheduler {
|
||||
"_class_name": "PNDMScheduler",
|
||||
"_diffusers_version": "0.21.4",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": false,
|
||||
"num_train_timesteps": 1000,
|
||||
"set_alpha_to_one": false,
|
||||
"skip_prk_steps": true,
|
||||
"steps_offset": 1,
|
||||
"timestep_spacing": "leading",
|
||||
"trained_betas": null
|
||||
}
|
||||
```
|
||||
|
||||
## Load a scheduler
|
||||
|
||||
Schedulers are defined by a configuration file that can be used by a variety of schedulers. Load a scheduler with the [`SchedulerMixin.from_pretrained`] method, and specify the `subfolder` parameter to load the configuration file into the correct subfolder of the pipeline repository.
|
||||
|
||||
For example, to load the [`DDIMScheduler`]:
|
||||
|
||||
```py
|
||||
from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
|
||||
ddim = DDIMScheduler.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler")
|
||||
```
|
||||
|
||||
Then you can pass the newly loaded scheduler to the pipeline.
|
||||
|
||||
```python
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=ddim, torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
## Compare schedulers
|
||||
|
||||
Schedulers have their own unique strengths and weaknesses, making it difficult to quantitatively compare which scheduler works best for a pipeline. You typically have to make a trade-off between denoising speed and denoising quality. We recommend trying out different schedulers to find one that works best for your use case. Call the `pipeline.scheduler.compatibles` attribute to see what schedulers are compatible with a pipeline.
|
||||
|
||||
Let's compare the [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], and the [`DPMSolverMultistepScheduler`] on the following prompt and seed.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
|
||||
generator = torch.Generator(device="cuda").manual_seed(8)
|
||||
```
|
||||
|
||||
To change the pipelines scheduler, use the [`~ConfigMixin.from_config`] method to load a different scheduler's `pipeline.scheduler.config` into the pipeline.
|
||||
|
||||
<hfoptions id="schedulers">
|
||||
<hfoption id="LMSDiscreteScheduler">
|
||||
|
||||
[`LMSDiscreteScheduler`] typically generates higher quality images than the default scheduler.
|
||||
|
||||
```py
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="EulerDiscreteScheduler">
|
||||
|
||||
[`EulerDiscreteScheduler`] can generate higher quality images in just 30 steps.
|
||||
|
||||
```py
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="EulerAncestralDiscreteScheduler">
|
||||
|
||||
[`EulerAncestralDiscreteScheduler`] can generate higher quality images in just 30 steps.
|
||||
|
||||
```py
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPMSolverMultistepScheduler">
|
||||
|
||||
[`DPMSolverMultistepScheduler`] provides a balance between speed and quality and can generate higher quality images in just 20 steps.
|
||||
|
||||
```py
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
dpm = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
scheduler=dpm,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler
|
||||
```
|
||||
|
||||
## Timestep schedules
|
||||
|
||||
Timestep or noise schedule decides how noise is distributed over the denoising process. The schedule can be linear or more concentrated toward the beginning or end. It is a precomputed sequence of noise levels generated from the scheduler's default configuration, but it can be customized to use other schedules.
|
||||
|
||||
> [!TIP]
|
||||
> The `timesteps` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
|
||||
|
||||
The example below uses the [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) schedule which can generate a high-quality image in 10 steps, significantly speeding up generation and reducing computation time.
|
||||
|
||||
Import the schedule and pass it to the `timesteps` argument in the pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
from diffusers.schedulers import AysSchedules
|
||||
|
||||
sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
|
||||
print(sampling_schedule)
|
||||
"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
|
||||
)
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
timesteps=sampling_schedule,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Rescaling schedules
|
||||
|
||||
Denoising should begin with pure noise and the signal-to-noise (SNR) ration should be zero. However, some models don't actually start from pure noise which makes it difficult to generate images at brightness extremes.
|
||||
|
||||
> [!TIP]
|
||||
> Train your own model with `v_prediction` by adding the `--prediction_type="v_prediction"` flag to your training script. You can also [search](https://huggingface.co/search/full-text?q=v_prediction&type=model) for existing models trained with `v_prediction`.
|
||||
|
||||
To fix this, a model must be trained with `v_prediction`. If a model is trained with `v_prediction`, then enable the following arguments in the scheduler.
|
||||
|
||||
- Set `rescale_betas_zero_snr=True` to rescale the noise schedule to the very last timestep with exactly zero SNR
|
||||
- Set `timestep_spacing="trailing"` to force sampling from the last timestep with pure noise
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, DDIMScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", device_map="cuda")
|
||||
|
||||
pipeline.scheduler = DDIMScheduler.from_config(
|
||||
pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
||||
)
|
||||
```
|
||||
|
||||
Set `guidance_rescale` in the pipeline to avoid overexposed images. A lower value increases brightness, but some details may appear washed out.
|
||||
|
||||
```py
|
||||
prompt = """
|
||||
cinematic photo of a snowy mountain at night with the northern lights aurora borealis
|
||||
overhead, 35mm photograph, film, professional, 4k, highly detailed
|
||||
"""
|
||||
image = pipeline(prompt, guidance_rescale=0.7).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/no-zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">default Stable Diffusion v2-1 image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/zero-snr.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">image with zero SNR and trailing timestep spacing enabled</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Timestep spacing
|
||||
|
||||
Timestep spacing refers to the specific steps *t* to sample from from the schedule. Diffusers provides three spacing types as shown below.
|
||||
|
||||
| spacing strategy | spacing calculation | example timesteps |
|
||||
|---|---|---|
|
||||
| `leading` | evenly spaced steps | `[900, 800, 700, ..., 100, 0]` |
|
||||
| `linspace` | include first and last steps and evenly divide remaining intermediate steps | `[1000, 888.89, 777.78, ..., 111.11, 0]` |
|
||||
| `trailing` | include last step and evenly divide remaining intermediate steps beginning from the end | `[999, 899, 799, 699, 599, 499, 399, 299, 199, 99]` |
|
||||
|
||||
Pass the spacing strategy to the `timestep_spacing` argument in the scheduler.
|
||||
|
||||
> [!TIP]
|
||||
> The `trailing` strategy typically produces higher quality images with more details with fewer steps, but the difference in quality is not as obvious for more standard step values.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, timestep_spacing="trailing"
|
||||
)
|
||||
|
||||
prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
num_inference_steps=5,
|
||||
).images[0]
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/trailing_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">trailing spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/leading_spacing.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">leading spacing after 5 steps</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Sigmas
|
||||
|
||||
Sigmas is a measure of how noisy a sample is at a certain step as defined by the schedule. When using custom `sigmas`, the `timesteps` are calculated from these values instead of the default scheduler configuration.
|
||||
|
||||
> [!TIP]
|
||||
> The `sigmas` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
|
||||
|
||||
Pass the custom sigmas to the `sigmas` argument in the pipeline. The example below uses the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) from the 10-step AYS schedule.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
|
||||
)
|
||||
|
||||
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
sigmas=sigmas,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Karras sigmas
|
||||
|
||||
[Karras sigmas](https://huggingface.co/papers/2206.00364) resamples the noise schedule for more efficient sampling by clustering sigmas more densely in the middle of the sequence where structure reconstruction is critical, while using fewer sigmas at the beginning and end where noise changes have less impact. This can increase the level of details in a generated image.
|
||||
|
||||
Set `use_karras_sigmas=True` in the scheduler to enable it.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config,
|
||||
algorithm_type="sde-dpmsolver++",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
|
||||
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
sigmas=sigmas,
|
||||
).images[0]
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_true.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas enabled</figcaption>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">LMSDiscreteScheduler</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_false.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas disabled</figcaption>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">EulerDiscreteScheduler</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">EulerAncestralDiscreteScheduler</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">DPMSolverMultistepScheduler</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas. It should only be used for models trained with Karras sigmas.
|
||||
Most images look very similar and are comparable in quality. Again, it often comes down to your specific use case so a good approach is to run multiple different schedulers and compare the results.
|
||||
|
||||
## Choosing a scheduler
|
||||
## Models
|
||||
|
||||
It's important to try different schedulers to find the best one for your use case. Here are a few recommendations to help you get started.
|
||||
Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them.
|
||||
|
||||
- DPM++ 2M SDE Karras is generally a good all-purpose option.
|
||||
- [`TCDScheduler`] works well for distilled models.
|
||||
- [`FlowMatchEulerDiscreteScheduler`] and [`FlowMatchHeunDiscreteScheduler`] for FlowMatch models.
|
||||
- [`EulerDiscreteScheduler`] or [`EulerAncestralDiscreteScheduler`] for generating anime style images.
|
||||
- DPM++ 2M paired with [`LCMScheduler`] on SDXL for generating realistic images.
|
||||
Models can be loaded from a subfolder with the `subfolder` argument. For example, the model weights for [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) are stored in the [unet](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet) subfolder.
|
||||
|
||||
## Resources
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
- Read the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more details about rescaling the noise schedule to enforce zero SNR.
|
||||
unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", use_safetensors=True)
|
||||
```
|
||||
|
||||
They can also be directly loaded from a [repository](https://huggingface.co/google/ddpm-cifar10-32/tree/main).
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DModel
|
||||
|
||||
unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
|
||||
```
|
||||
|
||||
To load and save model variants, specify the `variant` argument in [`ModelMixin.from_pretrained`] and [`ModelMixin.save_pretrained`].
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
|
||||
)
|
||||
unet.save_pretrained("./local-unet", variant="non_ema")
|
||||
```
|
||||
|
||||
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
|
||||
unet = AutoModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights unlike the `torch_dtype` argument that respects the `_keep_in_fp32_modules`. This is important for models whose layers must remain in fp32 for numerical stability and best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToImageInput(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
size: str | None = None
|
||||
n: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PresetModels:
|
||||
SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
|
||||
SD3_5: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
"stabilityai/stable-diffusion-3.5-large-turbo",
|
||||
"stabilityai/stable-diffusion-3.5-medium",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TextToImagePipelineSD3:
|
||||
def __init__(self, model_path: str | None = None):
|
||||
self.model_path = model_path or os.getenv("MODEL_PATH")
|
||||
self.pipeline: StableDiffusion3Pipeline | None = None
|
||||
self.device: str | None = None
|
||||
|
||||
def start(self):
|
||||
if torch.cuda.is_available():
|
||||
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
|
||||
logger.info("Loading CUDA")
|
||||
self.device = "cuda"
|
||||
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.float16,
|
||||
).to(device=self.device)
|
||||
elif torch.backends.mps.is_available():
|
||||
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
|
||||
logger.info("Loading MPS for Mac M Series")
|
||||
self.device = "mps"
|
||||
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to(device=self.device)
|
||||
else:
|
||||
raise Exception("No CUDA or MPS device available")
|
||||
|
||||
|
||||
class ModelPipelineInitializer:
|
||||
def __init__(self, model: str = "", type_models: str = "t2im"):
|
||||
self.model = model
|
||||
self.type_models = type_models
|
||||
self.pipeline = None
|
||||
self.device = "cuda" if torch.cuda.is_available() else "mps"
|
||||
self.model_type = None
|
||||
|
||||
def initialize_pipeline(self):
|
||||
if not self.model:
|
||||
raise ValueError("Model name not provided")
|
||||
|
||||
# Check if model exists in PresetModels
|
||||
preset_models = PresetModels()
|
||||
|
||||
# Determine which model type we're dealing with
|
||||
if self.model in preset_models.SD3:
|
||||
self.model_type = "SD3"
|
||||
elif self.model in preset_models.SD3_5:
|
||||
self.model_type = "SD3_5"
|
||||
|
||||
# Create appropriate pipeline based on model type and type_models
|
||||
if self.type_models == "t2im":
|
||||
if self.model_type in ["SD3", "SD3_5"]:
|
||||
self.pipeline = TextToImagePipelineSD3(self.model)
|
||||
else:
|
||||
raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
|
||||
elif self.type_models == "t2v":
|
||||
raise ValueError(f"Unsupported type_models: {self.type_models}")
|
||||
|
||||
return self.pipeline
|
||||
@@ -1,171 +0,0 @@
|
||||
# Asynchronous server and parallel execution of models
|
||||
|
||||
> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.
|
||||
> We recommend running 10 to 50 inferences in parallel for optimal performance, averaging between 25 and 30 seconds to 1 minute and 1 minute and 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two inferences in parallel to avoid decoding or saving errors due to memory shortages.)
|
||||
|
||||
## ⚠️ IMPORTANT
|
||||
|
||||
* The example demonstrates how to run pipelines like `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on GPU.
|
||||
|
||||
## Necessary components
|
||||
|
||||
All the components needed to create the inference server are in the current directory:
|
||||
|
||||
```
|
||||
server-async/
|
||||
├── utils/
|
||||
├─────── __init__.py
|
||||
├─────── scheduler.py # BaseAsyncScheduler wrapper and async_retrieve_timesteps for secure inferences
|
||||
├─────── requestscopedpipeline.py # RequestScoped Pipeline for inference with a single in-memory model
|
||||
├─────── utils.py # Image/video saving utilities and service configuration
|
||||
├── Pipelines.py # pipeline loader classes (SD3)
|
||||
├── serverasync.py # FastAPI app with lifespan management and async inference endpoints
|
||||
├── test.py # Client test script for inference requests
|
||||
├── requirements.txt # Dependencies
|
||||
└── README.md # This documentation
|
||||
```
|
||||
|
||||
## What `diffusers-async` adds / Why we needed it
|
||||
|
||||
Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline per-request.
|
||||
|
||||
`diffusers-async` / this example addresses that by:
|
||||
|
||||
* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.
|
||||
* **Per-request mutable state**: stateful small objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with fallback to safe `deepcopy` or other heuristics.
|
||||
* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.
|
||||
* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully retro-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not using `return_scheduler=True`, the behavior is identical to the original API.
|
||||
* **Robust attribute handling**: wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. Configurable tensor size threshold prevents cloning of large tensors.
|
||||
* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).
|
||||
|
||||
## How the server works (high-level flow)
|
||||
|
||||
1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.
|
||||
2. On each HTTP inference request:
|
||||
|
||||
* The server uses `RequestScopedPipeline.generate(...)` which:
|
||||
|
||||
* automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped),
|
||||
* obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),
|
||||
* does `local_pipe = copy.copy(base_pipe)` (shallow copy),
|
||||
* sets `local_pipe.scheduler = local_scheduler` (if possible),
|
||||
* clones only small mutable attributes (callbacks, rng, small latents) with auto-detection,
|
||||
* wraps tokenizers with thread-safe locks to prevent race conditions,
|
||||
* optionally enters a `model_cpu_offload_context()` for memory offload hooks,
|
||||
* calls the pipeline on the local view (`local_pipe(...)`).
|
||||
3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`).
|
||||
4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state.
|
||||
|
||||
## How to set up and run the server
|
||||
|
||||
### 1) Install dependencies
|
||||
|
||||
Recommended: create a virtualenv / conda environment.
|
||||
|
||||
```bash
|
||||
pip install diffusers
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2) Start the server
|
||||
|
||||
Using the `serverasync.py` file that already has everything you need:
|
||||
|
||||
```bash
|
||||
python serverasync.py
|
||||
```
|
||||
|
||||
The server will start on `http://localhost:8500` by default with the following features:
|
||||
- FastAPI application with async lifespan management
|
||||
- Automatic model loading and pipeline initialization
|
||||
- Request counting and active inference tracking
|
||||
- Memory cleanup after each inference
|
||||
- CORS middleware for cross-origin requests
|
||||
|
||||
### 3) Test the server
|
||||
|
||||
Use the included test script:
|
||||
|
||||
```bash
|
||||
python test.py
|
||||
```
|
||||
|
||||
Or send a manual request:
|
||||
|
||||
`POST /api/diffusers/inference` with JSON body:
|
||||
|
||||
```json
|
||||
{
|
||||
"prompt": "A futuristic cityscape, vibrant colors",
|
||||
"num_inference_steps": 30,
|
||||
"num_images_per_prompt": 1
|
||||
}
|
||||
```
|
||||
|
||||
Response example:
|
||||
|
||||
```json
|
||||
{
|
||||
"response": ["http://localhost:8500/images/img123.png"]
|
||||
}
|
||||
```
|
||||
|
||||
### 4) Server endpoints
|
||||
|
||||
- `GET /` - Welcome message
|
||||
- `POST /api/diffusers/inference` - Main inference endpoint
|
||||
- `GET /images/{filename}` - Serve generated images
|
||||
- `GET /api/status` - Server status and memory info
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### RequestScopedPipeline Parameters
|
||||
|
||||
```python
|
||||
RequestScopedPipeline(
|
||||
pipeline, # Base pipeline to wrap
|
||||
mutable_attrs=None, # Custom list of attributes to clone
|
||||
auto_detect_mutables=True, # Enable automatic detection of mutable attributes
|
||||
tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning
|
||||
tokenizer_lock=None, # Custom threading lock for tokenizers
|
||||
wrap_scheduler=True # Auto-wrap scheduler in BaseAsyncScheduler
|
||||
)
|
||||
```
|
||||
|
||||
### BaseAsyncScheduler Features
|
||||
|
||||
* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`
|
||||
* `clone_for_request()` method for safe per-request scheduler cloning
|
||||
* Enhanced debugging with `__repr__` and `__str__` methods
|
||||
* Full compatibility with existing scheduler APIs
|
||||
|
||||
### Server Configuration
|
||||
|
||||
The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ServerConfigModels:
|
||||
model: str = 'stabilityai/stable-diffusion-3.5-medium'
|
||||
type_models: str = 't2im'
|
||||
host: str = '0.0.0.0'
|
||||
port: int = 8500
|
||||
```
|
||||
|
||||
## Troubleshooting (quick)
|
||||
|
||||
* `Already borrowed` — previously a Rust tokenizer concurrency error.
|
||||
✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen.
|
||||
|
||||
* `can't set attribute 'components'` — pipeline exposes read-only `components`.
|
||||
✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically.
|
||||
|
||||
* Scheduler issues:
|
||||
* If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log and fallback — but prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.
|
||||
✅ Note: `async_retrieve_timesteps` is fully retro-compatible — if you don't pass `return_scheduler=True`, the behavior is unchanged.
|
||||
|
||||
* Memory issues with large tensors:
|
||||
✅ The system now has configurable `tensor_numel_threshold` to prevent cloning of large tensors while still cloning small mutable ones.
|
||||
|
||||
* Automatic tokenizer detection:
|
||||
✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers.
|
||||
@@ -1,10 +0,0 @@
|
||||
torch
|
||||
torchvision
|
||||
transformers
|
||||
sentencepiece
|
||||
fastapi
|
||||
uvicorn
|
||||
ftfy
|
||||
accelerate
|
||||
xformers
|
||||
protobuf
|
||||
@@ -1,230 +0,0 @@
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from Pipelines import ModelPipelineInitializer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from utils import RequestScopedPipeline, Utils
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerConfigModels:
|
||||
model: str = "stabilityai/stable-diffusion-3.5-medium"
|
||||
type_models: str = "t2im"
|
||||
constructor_pipeline: Optional[Type] = None
|
||||
custom_pipeline: Optional[Type] = None
|
||||
components: Optional[Dict[str, Any]] = None
|
||||
torch_dtype: Optional[torch.dtype] = None
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8500
|
||||
|
||||
|
||||
server_config = ServerConfigModels()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
app.state.logger = logging.getLogger("diffusers-server")
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
|
||||
|
||||
app.state.total_requests = 0
|
||||
app.state.active_inferences = 0
|
||||
app.state.metrics_lock = asyncio.Lock()
|
||||
app.state.metrics_task = None
|
||||
|
||||
app.state.utils_app = Utils(
|
||||
host=server_config.host,
|
||||
port=server_config.port,
|
||||
)
|
||||
|
||||
async def metrics_loop():
|
||||
try:
|
||||
while True:
|
||||
async with app.state.metrics_lock:
|
||||
total = app.state.total_requests
|
||||
active = app.state.active_inferences
|
||||
app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
app.state.logger.info("Metrics loop cancelled")
|
||||
raise
|
||||
|
||||
app.state.metrics_task = asyncio.create_task(metrics_loop())
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
task = app.state.metrics_task
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
|
||||
if callable(stop_fn):
|
||||
await run_in_threadpool(stop_fn)
|
||||
except Exception as e:
|
||||
app.state.logger.warning(f"Error during pipeline shutdown: {e}")
|
||||
|
||||
app.state.logger.info("Lifespan shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
logger = logging.getLogger("DiffusersServer.Pipelines")
|
||||
|
||||
|
||||
initializer = ModelPipelineInitializer(
|
||||
model=server_config.model,
|
||||
type_models=server_config.type_models,
|
||||
)
|
||||
model_pipeline = initializer.initialize_pipeline()
|
||||
model_pipeline.start()
|
||||
|
||||
request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
|
||||
pipeline_lock = threading.Lock()
|
||||
|
||||
logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})")
|
||||
|
||||
app.state.MODEL_INITIALIZER = initializer
|
||||
app.state.MODEL_PIPELINE = model_pipeline
|
||||
app.state.REQUEST_PIPE = request_pipe
|
||||
app.state.PIPELINE_LOCK = pipeline_lock
|
||||
|
||||
|
||||
class JSONBodyQueryAPI(BaseModel):
|
||||
model: str | None = None
|
||||
prompt: str
|
||||
negative_prompt: str | None = None
|
||||
num_inference_steps: int = 28
|
||||
num_images_per_prompt: int = 1
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def count_requests_middleware(request: Request, call_next):
|
||||
async with app.state.metrics_lock:
|
||||
app.state.total_requests += 1
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Welcome to the Diffusers Server"}
|
||||
|
||||
|
||||
@app.post("/api/diffusers/inference")
|
||||
async def api(json: JSONBodyQueryAPI):
|
||||
prompt = json.prompt
|
||||
negative_prompt = json.negative_prompt or ""
|
||||
num_steps = json.num_inference_steps
|
||||
num_images_per_prompt = json.num_images_per_prompt
|
||||
|
||||
wrapper = app.state.MODEL_PIPELINE
|
||||
initializer = app.state.MODEL_INITIALIZER
|
||||
|
||||
utils_app = app.state.utils_app
|
||||
|
||||
if not wrapper or not wrapper.pipeline:
|
||||
raise HTTPException(500, "Model not initialized correctly")
|
||||
if not prompt.strip():
|
||||
raise HTTPException(400, "No prompt provided")
|
||||
|
||||
def make_generator():
|
||||
g = torch.Generator(device=initializer.device)
|
||||
return g.manual_seed(random.randint(0, 10_000_000))
|
||||
|
||||
req_pipe = app.state.REQUEST_PIPE
|
||||
|
||||
def infer():
|
||||
gen = make_generator()
|
||||
return req_pipe.generate(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=gen,
|
||||
num_inference_steps=num_steps,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=initializer.device,
|
||||
output_type="pil",
|
||||
)
|
||||
|
||||
try:
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences += 1
|
||||
|
||||
output = await run_in_threadpool(infer)
|
||||
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences = max(0, app.state.active_inferences - 1)
|
||||
|
||||
urls = [utils_app.save_image(img) for img in output.images]
|
||||
return {"response": urls}
|
||||
|
||||
except Exception as e:
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences = max(0, app.state.active_inferences - 1)
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise HTTPException(500, f"Error in processing: {e}")
|
||||
|
||||
finally:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.ipc_collect()
|
||||
gc.collect()
|
||||
|
||||
|
||||
@app.get("/images/{filename}")
|
||||
async def serve_image(filename: str):
|
||||
utils_app = app.state.utils_app
|
||||
file_path = os.path.join(utils_app.image_dir, filename)
|
||||
if not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
return FileResponse(file_path, media_type="image/png")
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status():
|
||||
memory_info = {}
|
||||
if torch.cuda.is_available():
|
||||
memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
|
||||
memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB
|
||||
memory_info = {
|
||||
"memory_allocated_gb": round(memory_allocated, 2),
|
||||
"memory_reserved_gb": round(memory_reserved, 2),
|
||||
"device": torch.cuda.get_device_name(0),
|
||||
}
|
||||
|
||||
return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host=server_config.host, port=server_config.port)
|
||||
@@ -1,65 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
SERVER_URL = "http://localhost:8500/api/diffusers/inference"
|
||||
BASE_URL = "http://localhost:8500"
|
||||
DOWNLOAD_FOLDER = "generated_images"
|
||||
WAIT_BEFORE_DOWNLOAD = 2 # seconds
|
||||
|
||||
os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)
|
||||
|
||||
|
||||
def save_from_url(url: str) -> str:
|
||||
"""Download the given URL (relative or absolute) and save it locally."""
|
||||
if url.startswith("/"):
|
||||
direct = BASE_URL.rstrip("/") + url
|
||||
else:
|
||||
direct = url
|
||||
resp = requests.get(direct, timeout=60)
|
||||
resp.raise_for_status()
|
||||
filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png"
|
||||
path = os.path.join(DOWNLOAD_FOLDER, filename)
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
return path
|
||||
|
||||
|
||||
def main():
|
||||
payload = {
|
||||
"prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
|
||||
"num_inference_steps": 30,
|
||||
"num_images_per_prompt": 1,
|
||||
}
|
||||
|
||||
print("Sending request...")
|
||||
try:
|
||||
r = requests.post(SERVER_URL, json=payload, timeout=480)
|
||||
r.raise_for_status()
|
||||
except Exception as e:
|
||||
print(f"Request failed: {e}")
|
||||
return
|
||||
|
||||
body = r.json().get("response", [])
|
||||
# Normalize to a list
|
||||
urls = body if isinstance(body, list) else [body] if body else []
|
||||
if not urls:
|
||||
print("No URLs found in the response. Check the server output.")
|
||||
return
|
||||
|
||||
print(f"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...")
|
||||
time.sleep(WAIT_BEFORE_DOWNLOAD)
|
||||
|
||||
for u in urls:
|
||||
try:
|
||||
path = save_from_url(u)
|
||||
print(f"Image saved to: {path}")
|
||||
except Exception as e:
|
||||
print(f"Error downloading {u}: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,2 +0,0 @@
|
||||
from .requestscopedpipeline import RequestScopedPipeline
|
||||
from .utils import Utils
|
||||
@@ -1,296 +0,0 @@
|
||||
import copy
|
||||
import threading
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def safe_tokenize(tokenizer, *args, lock, **kwargs):
|
||||
with lock:
|
||||
return tokenizer(*args, **kwargs)
|
||||
|
||||
|
||||
class RequestScopedPipeline:
|
||||
DEFAULT_MUTABLE_ATTRS = [
|
||||
"_all_hooks",
|
||||
"_offload_device",
|
||||
"_progress_bar_config",
|
||||
"_progress_bar",
|
||||
"_rng_state",
|
||||
"_last_seed",
|
||||
"latents",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline: Any,
|
||||
mutable_attrs: Optional[Iterable[str]] = None,
|
||||
auto_detect_mutables: bool = True,
|
||||
tensor_numel_threshold: int = 1_000_000,
|
||||
tokenizer_lock: Optional[threading.Lock] = None,
|
||||
wrap_scheduler: bool = True,
|
||||
):
|
||||
self._base = pipeline
|
||||
self.unet = getattr(pipeline, "unet", None)
|
||||
self.vae = getattr(pipeline, "vae", None)
|
||||
self.text_encoder = getattr(pipeline, "text_encoder", None)
|
||||
self.components = getattr(pipeline, "components", None)
|
||||
|
||||
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
|
||||
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
|
||||
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
|
||||
|
||||
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
|
||||
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
|
||||
|
||||
self._auto_detect_mutables = bool(auto_detect_mutables)
|
||||
self._tensor_numel_threshold = int(tensor_numel_threshold)
|
||||
|
||||
self._auto_detected_attrs: List[str] = []
|
||||
|
||||
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
|
||||
base_sched = getattr(self._base, "scheduler", None)
|
||||
if base_sched is None:
|
||||
return None
|
||||
|
||||
if not isinstance(base_sched, BaseAsyncScheduler):
|
||||
wrapped_scheduler = BaseAsyncScheduler(base_sched)
|
||||
else:
|
||||
wrapped_scheduler = base_sched
|
||||
|
||||
try:
|
||||
return wrapped_scheduler.clone_for_request(
|
||||
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
|
||||
try:
|
||||
return copy.deepcopy(wrapped_scheduler)
|
||||
except Exception as e:
|
||||
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
|
||||
return wrapped_scheduler
|
||||
|
||||
def _autodetect_mutables(self, max_attrs: int = 40):
|
||||
if not self._auto_detect_mutables:
|
||||
return []
|
||||
|
||||
if self._auto_detected_attrs:
|
||||
return self._auto_detected_attrs
|
||||
|
||||
candidates: List[str] = []
|
||||
seen = set()
|
||||
for name in dir(self._base):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
if name in self._mutable_attrs:
|
||||
continue
|
||||
if name in ("to", "save_pretrained", "from_pretrained"):
|
||||
continue
|
||||
try:
|
||||
val = getattr(self._base, name)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
import types
|
||||
|
||||
# skip callables and modules
|
||||
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
|
||||
continue
|
||||
|
||||
# containers -> candidate
|
||||
if isinstance(val, (dict, list, set, tuple, bytearray)):
|
||||
candidates.append(name)
|
||||
seen.add(name)
|
||||
else:
|
||||
# try Tensor detection
|
||||
try:
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.numel() <= self._tensor_numel_threshold:
|
||||
candidates.append(name)
|
||||
seen.add(name)
|
||||
else:
|
||||
logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if len(candidates) >= max_attrs:
|
||||
break
|
||||
|
||||
self._auto_detected_attrs = candidates
|
||||
logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
|
||||
return self._auto_detected_attrs
|
||||
|
||||
def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
|
||||
try:
|
||||
cls = type(base_obj)
|
||||
descriptor = getattr(cls, attr_name, None)
|
||||
if isinstance(descriptor, property):
|
||||
return descriptor.fset is None
|
||||
if hasattr(descriptor, "__set__") is False and descriptor is not None:
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _clone_mutable_attrs(self, base, local):
|
||||
attrs_to_clone = list(self._mutable_attrs)
|
||||
attrs_to_clone.extend(self._autodetect_mutables())
|
||||
|
||||
EXCLUDE_ATTRS = {
|
||||
"components",
|
||||
}
|
||||
|
||||
for attr in attrs_to_clone:
|
||||
if attr in EXCLUDE_ATTRS:
|
||||
logger.debug(f"Skipping excluded attr '{attr}'")
|
||||
continue
|
||||
if not hasattr(base, attr):
|
||||
continue
|
||||
if self._is_readonly_property(base, attr):
|
||||
logger.debug(f"Skipping read-only property '{attr}'")
|
||||
continue
|
||||
|
||||
try:
|
||||
val = getattr(base, attr)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
setattr(local, attr, dict(val))
|
||||
elif isinstance(val, (list, tuple, set)):
|
||||
setattr(local, attr, list(val))
|
||||
elif isinstance(val, bytearray):
|
||||
setattr(local, attr, bytearray(val))
|
||||
else:
|
||||
# small tensors or atomic values
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.numel() <= self._tensor_numel_threshold:
|
||||
setattr(local, attr, val.clone())
|
||||
else:
|
||||
# don't clone big tensors, keep reference
|
||||
setattr(local, attr, val)
|
||||
else:
|
||||
try:
|
||||
setattr(local, attr, copy.copy(val))
|
||||
except Exception:
|
||||
setattr(local, attr, val)
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
|
||||
continue
|
||||
|
||||
def _is_tokenizer_component(self, component) -> bool:
|
||||
if component is None:
|
||||
return False
|
||||
|
||||
tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
|
||||
has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
|
||||
|
||||
class_name = component.__class__.__name__.lower()
|
||||
has_tokenizer_in_name = "tokenizer" in class_name
|
||||
|
||||
tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
|
||||
has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
|
||||
|
||||
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
|
||||
|
||||
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
|
||||
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
|
||||
|
||||
try:
|
||||
local_pipe = copy.copy(self._base)
|
||||
except Exception as e:
|
||||
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
|
||||
local_pipe = copy.deepcopy(self._base)
|
||||
|
||||
if local_scheduler is not None:
|
||||
try:
|
||||
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
|
||||
local_scheduler.scheduler,
|
||||
num_inference_steps=num_inference_steps,
|
||||
device=device,
|
||||
return_scheduler=True,
|
||||
**{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
|
||||
)
|
||||
|
||||
final_scheduler = BaseAsyncScheduler(configured_scheduler)
|
||||
setattr(local_pipe, "scheduler", final_scheduler)
|
||||
except Exception:
|
||||
logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
|
||||
|
||||
self._clone_mutable_attrs(self._base, local_pipe)
|
||||
|
||||
# 4) wrap tokenizers on the local pipe with the lock wrapper
|
||||
tokenizer_wrappers = {} # name -> original_tokenizer
|
||||
try:
|
||||
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
|
||||
for name in dir(local_pipe):
|
||||
if "tokenizer" in name and not name.startswith("_"):
|
||||
tok = getattr(local_pipe, name, None)
|
||||
if tok is not None and self._is_tokenizer_component(tok):
|
||||
tokenizer_wrappers[name] = tok
|
||||
setattr(
|
||||
local_pipe,
|
||||
name,
|
||||
lambda *args, tok=tok, **kwargs: safe_tokenize(
|
||||
tok, *args, lock=self._tokenizer_lock, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
# b) wrap tokenizers in components dict
|
||||
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
|
||||
for key, val in local_pipe.components.items():
|
||||
if val is None:
|
||||
continue
|
||||
|
||||
if self._is_tokenizer_component(val):
|
||||
tokenizer_wrappers[f"components[{key}]"] = val
|
||||
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
|
||||
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
|
||||
|
||||
result = None
|
||||
cm = getattr(local_pipe, "model_cpu_offload_context", None)
|
||||
try:
|
||||
if callable(cm):
|
||||
try:
|
||||
with cm():
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
except TypeError:
|
||||
# cm might be a context manager instance rather than callable
|
||||
try:
|
||||
with cm:
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
except Exception as e:
|
||||
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
else:
|
||||
# no offload context available — call directly
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
try:
|
||||
for name, tok in tokenizer_wrappers.items():
|
||||
if name.startswith("components["):
|
||||
key = name[len("components[") : -1]
|
||||
local_pipe.components[key] = tok
|
||||
else:
|
||||
setattr(local_pipe, name, tok)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error restoring wrapped tokenizers: {e}")
|
||||
@@ -1,141 +0,0 @@
|
||||
import copy
|
||||
import inspect
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseAsyncScheduler:
|
||||
def __init__(self, scheduler: Any):
|
||||
self.scheduler = scheduler
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if hasattr(self.scheduler, name):
|
||||
return getattr(self.scheduler, name)
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def __setattr__(self, name: str, value):
|
||||
if name == "scheduler":
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
|
||||
setattr(self.scheduler, name, value)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
|
||||
local = copy.deepcopy(self.scheduler)
|
||||
local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
|
||||
cloned = self.__class__(local)
|
||||
return cloned
|
||||
|
||||
def __repr__(self):
|
||||
return f"BaseAsyncScheduler({repr(self.scheduler)})"
|
||||
|
||||
def __str__(self):
|
||||
return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
|
||||
|
||||
|
||||
def async_retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
|
||||
Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Backwards compatible: by default the function behaves exactly as before and returns
|
||||
(timesteps_tensor, num_inference_steps)
|
||||
|
||||
If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
|
||||
scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
|
||||
or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:
|
||||
(timesteps_tensor, num_inference_steps, scheduler_in_use)
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Optional kwargs:
|
||||
return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)
|
||||
where `scheduler_in_use` is a scheduler instance that already has timesteps set.
|
||||
This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.
|
||||
|
||||
Returns:
|
||||
`(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
|
||||
`(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
|
||||
"""
|
||||
# pop our optional control kwarg (keeps compatibility)
|
||||
return_scheduler = bool(kwargs.pop("return_scheduler", False))
|
||||
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
|
||||
# choose scheduler to call set_timesteps on
|
||||
scheduler_in_use = scheduler
|
||||
if return_scheduler:
|
||||
# Do not mutate the provided scheduler: prefer to clone if possible
|
||||
if hasattr(scheduler, "clone_for_request"):
|
||||
try:
|
||||
# clone_for_request may accept num_inference_steps or other kwargs; be permissive
|
||||
scheduler_in_use = scheduler.clone_for_request(
|
||||
num_inference_steps=num_inference_steps or 0, device=device
|
||||
)
|
||||
except Exception:
|
||||
scheduler_in_use = copy.deepcopy(scheduler)
|
||||
else:
|
||||
# fallback deepcopy (scheduler tends to be smallish - acceptable)
|
||||
scheduler_in_use = copy.deepcopy(scheduler)
|
||||
|
||||
# helper to test if set_timesteps supports a particular kwarg
|
||||
def _accepts(param_name: str) -> bool:
|
||||
try:
|
||||
return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())
|
||||
except (ValueError, TypeError):
|
||||
# if signature introspection fails, be permissive and attempt the call later
|
||||
return False
|
||||
|
||||
# now call set_timesteps on the chosen scheduler_in_use (may be original or clone)
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = _accepts("timesteps")
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
num_inference_steps = len(timesteps_out)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = _accepts("sigmas")
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
num_inference_steps = len(timesteps_out)
|
||||
else:
|
||||
# default path
|
||||
scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
|
||||
if return_scheduler:
|
||||
return timesteps_out, num_inference_steps, scheduler_in_use
|
||||
return timesteps_out, num_inference_steps
|
||||
@@ -1,48 +0,0 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Utils:
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8500):
|
||||
self.service_url = f"http://{host}:{port}"
|
||||
self.image_dir = os.path.join(tempfile.gettempdir(), "images")
|
||||
if not os.path.exists(self.image_dir):
|
||||
os.makedirs(self.image_dir)
|
||||
|
||||
self.video_dir = os.path.join(tempfile.gettempdir(), "videos")
|
||||
if not os.path.exists(self.video_dir):
|
||||
os.makedirs(self.video_dir)
|
||||
|
||||
def save_image(self, image):
|
||||
if hasattr(image, "to"):
|
||||
try:
|
||||
image = image.to("cpu")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
from torchvision import transforms
|
||||
|
||||
to_pil = transforms.ToPILImage()
|
||||
image = to_pil(image.squeeze(0).clamp(0, 1))
|
||||
|
||||
filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png"
|
||||
image_path = os.path.join(self.image_dir, filename)
|
||||
logger.info(f"Saving image to {image_path}")
|
||||
|
||||
image.save(image_path, format="PNG", optimize=True)
|
||||
|
||||
del image
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return os.path.join(self.service_url, "images", filename)
|
||||
@@ -9,8 +9,8 @@ This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server
|
||||
Start by navigating to the `examples/server` folder and installing all of the dependencies.
|
||||
|
||||
```py
|
||||
pip install diffusers
|
||||
pip install -r requirements.txt
|
||||
pip install .
|
||||
pip install -f requirements.txt
|
||||
```
|
||||
|
||||
Launch the server with the following command.
|
||||
|
||||
@@ -6,5 +6,4 @@ py-consul
|
||||
prometheus_client >= 0.18.0
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
fastapi
|
||||
uvicorn
|
||||
accelerate
|
||||
uvicorn
|
||||
@@ -39,7 +39,7 @@ fsspec==2024.10.0
|
||||
# torch
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
huggingface-hub==0.35.0
|
||||
huggingface-hub==0.26.1
|
||||
# via
|
||||
# tokenizers
|
||||
# transformers
|
||||
|
||||
@@ -278,29 +278,6 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-VACE-Fun-14B":
|
||||
config = {
|
||||
"model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
|
||||
"vace_in_channels": 96,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-I2V-14B-720p":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.2-I2V-A14B",
|
||||
@@ -998,17 +975,7 @@ if __name__ == "__main__":
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
elif "Wan2.2-VACE" in args.model_type:
|
||||
pipe = WanVACEPipeline(
|
||||
transformer=transformer,
|
||||
transformer_2=transformer_2,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
boundary_ratio=0.875,
|
||||
)
|
||||
elif "Wan-VACE" in args.model_type:
|
||||
elif "VACE" in args.model_type:
|
||||
pipe = WanVACEPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
|
||||
@@ -102,8 +102,7 @@ _deps = [
|
||||
"filelock",
|
||||
"flax>=0.4.1",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"httpx<1.0.0",
|
||||
"huggingface-hub>=0.34.0,<2.0",
|
||||
"huggingface-hub>=0.34.0",
|
||||
"requests-mock==1.10.0",
|
||||
"importlib_metadata",
|
||||
"invisible-watermark>=0.2.0",
|
||||
@@ -260,7 +259,6 @@ extras["dev"] = (
|
||||
install_requires = [
|
||||
deps["importlib_metadata"],
|
||||
deps["filelock"],
|
||||
deps["httpx"],
|
||||
deps["huggingface-hub"],
|
||||
deps["numpy"],
|
||||
deps["regex"],
|
||||
|
||||
@@ -202,7 +202,6 @@ else:
|
||||
"CogView4Transformer2DModel",
|
||||
"ConsisIDTransformer3DModel",
|
||||
"ConsistencyDecoderVAE",
|
||||
"ContextParallelConfig",
|
||||
"ControlNetModel",
|
||||
"ControlNetUnionModel",
|
||||
"ControlNetXSAdapter",
|
||||
@@ -230,7 +229,6 @@ else:
|
||||
"MultiAdapter",
|
||||
"MultiControlNetModel",
|
||||
"OmniGenTransformer2DModel",
|
||||
"ParallelConfig",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"QwenImageControlNetModel",
|
||||
@@ -497,7 +495,6 @@ else:
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
"LucyEditPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
"LuminaPipeline",
|
||||
@@ -517,7 +514,6 @@ else:
|
||||
"QwenImageControlNetPipeline",
|
||||
"QwenImageEditInpaintPipeline",
|
||||
"QwenImageEditPipeline",
|
||||
"QwenImageEditPlusPipeline",
|
||||
"QwenImageImg2ImgPipeline",
|
||||
"QwenImageInpaintPipeline",
|
||||
"QwenImagePipeline",
|
||||
@@ -890,7 +886,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4Transformer2DModel,
|
||||
ConsisIDTransformer3DModel,
|
||||
ConsistencyDecoderVAE,
|
||||
ContextParallelConfig,
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
@@ -918,7 +913,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MultiAdapter,
|
||||
MultiControlNetModel,
|
||||
OmniGenTransformer2DModel,
|
||||
ParallelConfig,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
QwenImageControlNetModel,
|
||||
@@ -1155,7 +1149,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
LucyEditPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
LuminaPipeline,
|
||||
@@ -1175,7 +1168,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
QwenImageEditPipeline,
|
||||
QwenImageEditPlusPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImagePipeline,
|
||||
|
||||
@@ -30,11 +30,11 @@ import numpy as np
|
||||
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
from requests import HTTPError
|
||||
from typing_extensions import Self
|
||||
|
||||
from . import __version__
|
||||
@@ -419,7 +419,7 @@ class ConfigMixin:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
||||
)
|
||||
except HfHubHTTPError as err:
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
"There was a specific connection error when trying to load"
|
||||
f" {pretrained_model_name_or_path}:\n{err}"
|
||||
|
||||
@@ -9,8 +9,7 @@ deps = {
|
||||
"filelock": "filelock",
|
||||
"flax": "flax>=0.4.1",
|
||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||
"httpx": "httpx<1.0.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.34.0,<2.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.34.0",
|
||||
"requests-mock": "requests-mock==1.10.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"invisible-watermark": "invisible-watermark>=0.2.0",
|
||||
|
||||
@@ -16,7 +16,6 @@ from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .context_parallel import apply_context_parallel
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
|
||||
@@ -1,297 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
|
||||
from ..models._modeling_parallel import (
|
||||
ContextParallelConfig,
|
||||
ContextParallelInput,
|
||||
ContextParallelModelPlan,
|
||||
ContextParallelOutput,
|
||||
)
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
|
||||
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
|
||||
|
||||
|
||||
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
|
||||
@dataclass
|
||||
class ModuleForwardMetadata:
|
||||
cached_parameter_indices: Dict[str, int] = None
|
||||
_cls: Type = None
|
||||
|
||||
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
if identifier in kwargs:
|
||||
return kwargs[identifier], True, None
|
||||
|
||||
if self.cached_parameter_indices is not None:
|
||||
index = self.cached_parameter_indices.get(identifier, None)
|
||||
if index is None:
|
||||
raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
|
||||
return args[index], False, index
|
||||
|
||||
if self._cls is None:
|
||||
raise ValueError("Model class is not set for metadata.")
|
||||
|
||||
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
|
||||
parameters = parameters[1:] # skip `self`
|
||||
self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
|
||||
|
||||
if identifier not in self.cached_parameter_indices:
|
||||
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
|
||||
|
||||
index = self.cached_parameter_indices[identifier]
|
||||
|
||||
if index >= len(args):
|
||||
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
|
||||
|
||||
return args[index], False, index
|
||||
|
||||
|
||||
def apply_context_parallel(
|
||||
module: torch.nn.Module,
|
||||
parallel_config: ContextParallelConfig,
|
||||
plan: Dict[str, ContextParallelModelPlan],
|
||||
) -> None:
|
||||
"""Apply context parallel on a model."""
|
||||
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
|
||||
|
||||
for module_id, cp_model_plan in plan.items():
|
||||
submodule = _get_submodule_by_name(module, module_id)
|
||||
if not isinstance(submodule, list):
|
||||
submodule = [submodule]
|
||||
|
||||
logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
|
||||
|
||||
for m in submodule:
|
||||
if isinstance(cp_model_plan, dict):
|
||||
hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
|
||||
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
|
||||
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
|
||||
if isinstance(cp_model_plan, ContextParallelOutput):
|
||||
cp_model_plan = [cp_model_plan]
|
||||
if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
|
||||
raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
|
||||
hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
|
||||
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
|
||||
else:
|
||||
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(m)
|
||||
registry.register_hook(hook, hook_name)
|
||||
|
||||
|
||||
def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
|
||||
for module_id, cp_model_plan in plan.items():
|
||||
submodule = _get_submodule_by_name(module, module_id)
|
||||
if not isinstance(submodule, list):
|
||||
submodule = [submodule]
|
||||
|
||||
for m in submodule:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(m)
|
||||
if isinstance(cp_model_plan, dict):
|
||||
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
|
||||
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
|
||||
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
|
||||
else:
|
||||
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
|
||||
registry.remove_hook(hook_name)
|
||||
|
||||
|
||||
class ContextParallelSplitHook(ModelHook):
|
||||
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
|
||||
super().__init__()
|
||||
self.metadata = metadata
|
||||
self.parallel_config = parallel_config
|
||||
self.module_forward_metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
cls = unwrap_module(module).__class__
|
||||
self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
|
||||
return module
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
args_list = list(args)
|
||||
|
||||
for name, cpm in self.metadata.items():
|
||||
if isinstance(cpm, ContextParallelInput) and cpm.split_output:
|
||||
continue
|
||||
|
||||
# Maybe the parameter was passed as a keyword argument
|
||||
input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
|
||||
name, args_list, kwargs
|
||||
)
|
||||
|
||||
if input_val is None:
|
||||
continue
|
||||
|
||||
# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
|
||||
# the output instead of input for a particular layer by setting split_output=True
|
||||
if isinstance(input_val, torch.Tensor):
|
||||
input_val = self._prepare_cp_input(input_val, cpm)
|
||||
elif isinstance(input_val, (list, tuple)):
|
||||
if len(input_val) != len(cpm):
|
||||
raise ValueError(
|
||||
f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
|
||||
)
|
||||
sharded_input_val = []
|
||||
for i, x in enumerate(input_val):
|
||||
if torch.is_tensor(x) and not cpm[i].split_output:
|
||||
x = self._prepare_cp_input(x, cpm[i])
|
||||
sharded_input_val.append(x)
|
||||
input_val = sharded_input_val
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(input_val)}")
|
||||
|
||||
if is_kwarg:
|
||||
kwargs[name] = input_val
|
||||
elif index is not None and index < len(args_list):
|
||||
args_list[index] = input_val
|
||||
else:
|
||||
raise ValueError(
|
||||
f"An unexpected error occurred while processing the input '{name}'. Please open an "
|
||||
f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
|
||||
f"example along with the full stack trace."
|
||||
)
|
||||
|
||||
return tuple(args_list), kwargs
|
||||
|
||||
def post_forward(self, module, output):
|
||||
is_tensor = isinstance(output, torch.Tensor)
|
||||
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
|
||||
|
||||
if not is_tensor and not is_tensor_list:
|
||||
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
|
||||
|
||||
output = [output] if is_tensor else list(output)
|
||||
for index, cpm in self.metadata.items():
|
||||
if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
|
||||
continue
|
||||
if index >= len(output):
|
||||
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
|
||||
current_output = output[index]
|
||||
current_output = self._prepare_cp_input(current_output, cpm)
|
||||
output[index] = current_output
|
||||
|
||||
return output[0] if is_tensor else tuple(output)
|
||||
|
||||
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
|
||||
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
|
||||
raise ValueError(
|
||||
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
|
||||
)
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
|
||||
class ContextParallelGatherHook(ModelHook):
|
||||
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
|
||||
super().__init__()
|
||||
self.metadata = metadata
|
||||
self.parallel_config = parallel_config
|
||||
|
||||
def post_forward(self, module, output):
|
||||
is_tensor = isinstance(output, torch.Tensor)
|
||||
|
||||
if is_tensor:
|
||||
output = [output]
|
||||
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
|
||||
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
|
||||
|
||||
output = list(output)
|
||||
|
||||
if len(output) != len(self.metadata):
|
||||
raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
|
||||
|
||||
for i, cpm in enumerate(self.metadata):
|
||||
if cpm is None:
|
||||
continue
|
||||
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
return output[0] if is_tensor else tuple(output)
|
||||
|
||||
|
||||
class AllGatherFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, tensor, dim, group):
|
||||
ctx.dim = dim
|
||||
ctx.group = group
|
||||
ctx.world_size = torch.distributed.get_world_size(group)
|
||||
ctx.rank = torch.distributed.get_rank(group)
|
||||
return funcol.all_gather_tensor(tensor, dim, group=group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
|
||||
return grad_chunks[ctx.rank], None, None
|
||||
|
||||
|
||||
class EquipartitionSharder:
|
||||
@classmethod
|
||||
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
|
||||
# NOTE: the following assertion does not have to be true in general. We simply enforce it for now
|
||||
# because the alternate case has not yet been tested/required for any model.
|
||||
assert tensor.size()[dim] % mesh.size() == 0, (
|
||||
"Tensor size along dimension to be sharded must be divisible by mesh size"
|
||||
)
|
||||
|
||||
# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
|
||||
# return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
|
||||
|
||||
return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
|
||||
|
||||
@classmethod
|
||||
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
|
||||
tensor = tensor.contiguous()
|
||||
tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
|
||||
return tensor
|
||||
|
||||
|
||||
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
if name.count("*") > 1:
|
||||
raise ValueError("Wildcard '*' can only be used once in the name")
|
||||
return _find_submodule_by_name(model, name)
|
||||
|
||||
|
||||
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
if name == "":
|
||||
return model
|
||||
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
|
||||
if first_atom == "*":
|
||||
if not isinstance(model, torch.nn.ModuleList):
|
||||
raise ValueError("Wildcard '*' can only be used with ModuleList")
|
||||
submodules = []
|
||||
for submodule in model:
|
||||
subsubmodules = _find_submodule_by_name(submodule, remaining_name)
|
||||
if not isinstance(subsubmodules, list):
|
||||
subsubmodules = [subsubmodules]
|
||||
submodules.extend(subsubmodules)
|
||||
return submodules
|
||||
else:
|
||||
if hasattr(model, first_atom):
|
||||
submodule = getattr(model, first_atom)
|
||||
return _find_submodule_by_name(submodule, remaining_name)
|
||||
else:
|
||||
raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
|
||||
@@ -1064,41 +1064,6 @@ class LoraBaseMixin:
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
@classmethod
|
||||
def _save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
|
||||
lora_metadata: Dict[str, Optional[dict]],
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
"""
|
||||
Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
|
||||
pipeline types.
|
||||
"""
|
||||
state_dict = {}
|
||||
final_lora_adapter_metadata = {}
|
||||
|
||||
for prefix, layers in lora_layers.items():
|
||||
state_dict.update(cls.pack_weights(layers, prefix))
|
||||
|
||||
for prefix, metadata in lora_metadata.items():
|
||||
if metadata:
|
||||
final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
|
||||
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@@ -558,62 +558,70 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
ait_sd[target_key] = value
|
||||
|
||||
if any("guidance_in" in k for k in sds_sd):
|
||||
_convert_to_ai_toolkit(
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_guidance_in_in_layer",
|
||||
"time_text_embed.guidance_embedder.linear_1",
|
||||
)
|
||||
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_guidance_in_out_layer",
|
||||
"time_text_embed.guidance_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("img_in" in k for k in sds_sd):
|
||||
_convert_to_ai_toolkit(
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_img_in",
|
||||
"x_embedder",
|
||||
)
|
||||
|
||||
if any("txt_in" in k for k in sds_sd):
|
||||
_convert_to_ai_toolkit(
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_txt_in",
|
||||
"context_embedder",
|
||||
)
|
||||
|
||||
if any("time_in" in k for k in sds_sd):
|
||||
_convert_to_ai_toolkit(
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_time_in_in_layer",
|
||||
"time_text_embed.timestep_embedder.linear_1",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_time_in_out_layer",
|
||||
"time_text_embed.timestep_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("vector_in" in k for k in sds_sd):
|
||||
_convert_to_ai_toolkit(
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_vector_in_in_layer",
|
||||
"time_text_embed.text_embedder.linear_1",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_vector_in_out_layer",
|
||||
"time_text_embed.text_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("final_layer" in k for k in sds_sd):
|
||||
|
||||
+2372
-266
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,6 @@ from ..utils import (
|
||||
_import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
|
||||
_import_structure["auto_model"] = ["AutoModel"]
|
||||
@@ -120,7 +119,6 @@ if is_flax_available():
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from ._modeling_parallel import ContextParallelConfig, ParallelConfig
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .attention_dispatch import AttentionBackendName, attention_backend
|
||||
from .auto_model import AutoModel
|
||||
|
||||
@@ -1,241 +0,0 @@
|
||||
# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨
|
||||
# Experimental changes are subject to change and APIs may break without warning.
|
||||
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# TODO(aryan): add support for the following:
|
||||
# - Unified Attention
|
||||
# - More dispatcher attention backends
|
||||
# - CFG/Data Parallel
|
||||
# - Tensor Parallel
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextParallelConfig:
|
||||
"""
|
||||
Configuration for context parallelism.
|
||||
|
||||
Args:
|
||||
ring_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
ulysses_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert output and LSE to float32 for ring attention numerical stability.
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
|
||||
is supported.
|
||||
|
||||
"""
|
||||
|
||||
ring_degree: Optional[int] = None
|
||||
ulysses_degree: Optional[int] = None
|
||||
convert_to_fp32: bool = True
|
||||
# TODO: support alltoall
|
||||
rotate_method: Literal["allgather", "alltoall"] = "allgather"
|
||||
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
_device: torch.device = None
|
||||
_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_ring_local_rank: int = None
|
||||
_ulysses_local_rank: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.ring_degree is None:
|
||||
self.ring_degree = 1
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
|
||||
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._mesh = mesh
|
||||
if self.ring_degree is None:
|
||||
self.ring_degree = 1
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
)
|
||||
if self._flattened_mesh is None:
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
if self._ring_mesh is None:
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
if self._ulysses_mesh is None:
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
if self._ring_local_rank is None:
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
if self._ulysses_local_rank is None:
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelConfig:
|
||||
"""
|
||||
Configuration for applying different parallelisms.
|
||||
|
||||
Args:
|
||||
context_parallel_config (`ContextParallelConfig`, *optional*):
|
||||
Configuration for context parallelism.
|
||||
"""
|
||||
|
||||
context_parallel_config: Optional[ContextParallelConfig] = None
|
||||
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
_device: torch.device = None
|
||||
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
|
||||
def setup(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
*,
|
||||
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._cp_mesh = cp_mesh
|
||||
if self.context_parallel_config is not None:
|
||||
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextParallelInput:
|
||||
"""
|
||||
Configuration for splitting an input tensor across context parallel region.
|
||||
|
||||
Args:
|
||||
split_dim (`int`):
|
||||
The dimension along which to split the tensor.
|
||||
expected_dims (`int`, *optional*):
|
||||
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
|
||||
tensor has the expected number of dimensions before splitting.
|
||||
split_output (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
|
||||
This is useful for layers whose outputs should be split after it does some preprocessing on the inputs (ex:
|
||||
RoPE).
|
||||
"""
|
||||
|
||||
split_dim: int
|
||||
expected_dims: Optional[int] = None
|
||||
split_output: bool = False
|
||||
|
||||
def __repr__(self):
|
||||
return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextParallelOutput:
|
||||
"""
|
||||
Configuration for gathering an output tensor across context parallel region.
|
||||
|
||||
Args:
|
||||
gather_dim (`int`):
|
||||
The dimension along which to gather the tensor.
|
||||
expected_dims (`int`, *optional*):
|
||||
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
|
||||
tensor has the expected number of dimensions before gathering.
|
||||
"""
|
||||
|
||||
gather_dim: int
|
||||
expected_dims: Optional[int] = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"
|
||||
|
||||
|
||||
# A dictionary where keys denote the input to be split across context parallel region, and the
|
||||
# value denotes the sharding configuration.
|
||||
# If the key is a string, it denotes the name of the parameter in the forward function.
|
||||
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
|
||||
# to be split across context parallel region.
|
||||
ContextParallelInputType = Dict[
|
||||
Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
|
||||
]
|
||||
|
||||
# A dictionary where keys denote the output to be gathered across context parallel region, and the
|
||||
# value denotes the gathering configuration.
|
||||
ContextParallelOutputType = Union[
|
||||
ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
|
||||
]
|
||||
|
||||
# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
|
||||
# the module should be split/gathered across context parallel region.
|
||||
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
|
||||
|
||||
|
||||
# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
|
||||
#
|
||||
# Each model should define a _cp_plan attribute that contains information on how to shard/gather
|
||||
# tensors at different stages of the forward:
|
||||
#
|
||||
# ```python
|
||||
# _cp_plan = {
|
||||
# "": {
|
||||
# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
# },
|
||||
# "pos_embed": {
|
||||
# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
# },
|
||||
# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
# }
|
||||
# ```
|
||||
#
|
||||
# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
|
||||
# split/gathered according to this at the respective module level. Here, the following happens:
|
||||
# - "":
|
||||
# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
|
||||
# the actual forward logic of the QwenImageTransformer2DModel is run, we will splitthe inputs)
|
||||
# - "pos_embed":
|
||||
# we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (imag & text freqs),
|
||||
# we can individually specify how they should be split
|
||||
# - "proj_out":
|
||||
# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
|
||||
# layer forward has run).
|
||||
#
|
||||
# ContextParallelInput:
|
||||
# specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
|
||||
#
|
||||
# ContextParallelOutput:
|
||||
# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
|
||||
@@ -17,10 +17,11 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate
|
||||
from ..utils.import_utils import is_torch_npu_available, is_torch_version
|
||||
from ..utils import deprecate, get_logger, is_torch_npu_available, is_torch_version
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
@@ -31,6 +32,7 @@ ACT2CLS = {
|
||||
"gelu": nn.GELU,
|
||||
"relu": nn.ReLU,
|
||||
}
|
||||
KERNELS_REPO_ID = "kernels-community/activation"
|
||||
|
||||
|
||||
def get_activation(act_fn: str) -> nn.Module:
|
||||
@@ -90,6 +92,27 @@ class GELU(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# TODO: validation checks / consider making Python classes of activations like `transformers`
|
||||
# All of these are temporary for now.
|
||||
class CUDAOptimizedGELU(GELU):
|
||||
def __init__(self, *args, **kwargs):
|
||||
from kernels import get_kernel
|
||||
|
||||
activation = get_kernel("kernels-community/activation", revision="add_more_act")
|
||||
approximate = kwargs.get("approximate", "none")
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
if approximate == "none":
|
||||
self.act_fn = activation.layers.Gelu()
|
||||
elif approximate == "tanh":
|
||||
self.act_fn = activation.layers.GeluTanh()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
r"""
|
||||
A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.
|
||||
|
||||
@@ -241,7 +241,7 @@ class AttentionModuleMixin:
|
||||
op_fw, op_bw = attention_op
|
||||
dtype, *_ = op_fw.SUPPORTED_DTYPES
|
||||
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
|
||||
_ = xops.ops.memory_efficient_attention(q, q, q)
|
||||
_ = xops.memory_efficient_attention(q, q, q)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -674,7 +674,7 @@ class JointTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
):
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,6 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import logging
|
||||
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -115,8 +114,6 @@ class AutoModel(ConfigMixin):
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
trust_remote_cocde (`bool`, *optional*, defaults to `False`):
|
||||
Whether to trust remote code
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -143,22 +140,22 @@ class AutoModel(ConfigMixin):
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
```
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
|
||||
hub_kwargs_names = [
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
|
||||
|
||||
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
|
||||
load_config_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"token": token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
}
|
||||
|
||||
library = None
|
||||
orig_class_name = None
|
||||
@@ -192,35 +189,15 @@ class AutoModel(ConfigMixin):
|
||||
else:
|
||||
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
|
||||
|
||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||
trust_remote_code = resolve_trust_remote_code(trust_remote_code, pretrained_model_or_path, has_remote_code)
|
||||
if not has_remote_code and trust_remote_code:
|
||||
raise ValueError(
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
if has_remote_code and trust_remote_code:
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
module_file = module_file + ".py"
|
||||
model_cls = get_class_from_dynamic_module(
|
||||
pretrained_model_or_path,
|
||||
subfolder=subfolder,
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
model_cls, _ = get_class_obj_and_candidates(
|
||||
library_name=library,
|
||||
class_name=orig_class_name,
|
||||
importable_classes=ALL_IMPORTABLE_CLASSES,
|
||||
pipelines=None,
|
||||
is_pipeline_module=False,
|
||||
)
|
||||
model_cls, _ = get_class_obj_and_candidates(
|
||||
library_name=library,
|
||||
class_name=orig_class_name,
|
||||
importable_classes=ALL_IMPORTABLE_CLASSES,
|
||||
pipelines=None,
|
||||
is_pipeline_module=False,
|
||||
)
|
||||
|
||||
if model_cls is None:
|
||||
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
|
||||
|
||||
@@ -617,7 +617,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.size(0) > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
@@ -1052,7 +1052,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
is_residual=is_residual,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = scale_factor_spatial
|
||||
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
@@ -1145,13 +1145,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def _encode(self, x: torch.Tensor):
|
||||
_, _, num_frame, height, width = x.shape
|
||||
|
||||
self.clear_cache()
|
||||
if self.config.patch_size is not None:
|
||||
x = patchify(x, patch_size=self.config.patch_size)
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
self.clear_cache()
|
||||
if self.config.patch_size is not None:
|
||||
x = patchify(x, patch_size=self.config.patch_size)
|
||||
iter_ = 1 + (num_frame - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
|
||||
@@ -26,11 +26,11 @@ from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from huggingface_hub import create_repo, hf_hub_download
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
from requests import HTTPError
|
||||
|
||||
from .. import __version__, is_torch_available
|
||||
from ..utils import (
|
||||
@@ -385,7 +385,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
|
||||
)
|
||||
except HfHubHTTPError as err:
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
||||
f"{err}"
|
||||
|
||||
@@ -65,7 +65,6 @@ from ..utils.hub_utils import (
|
||||
populate_model_card,
|
||||
)
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
|
||||
from .model_loading_utils import (
|
||||
_caching_allocator_warmup,
|
||||
_determine_device_map,
|
||||
@@ -249,8 +248,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
_skip_layerwise_casting_patterns = None
|
||||
_supports_group_offloading = True
|
||||
_repeated_blocks = []
|
||||
_parallel_config = None
|
||||
_cp_plan = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -623,8 +620,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
def reset_attention_backend(self) -> None:
|
||||
"""
|
||||
Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
|
||||
set, or the torch native scaled dot product attention.
|
||||
Resets the attention backend for the model. Following calls to `forward` will use the environment default or
|
||||
the torch native scaled dot product attention.
|
||||
"""
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
@@ -963,7 +960,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
|
||||
|
||||
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
|
||||
if is_parallel_loading_enabled and not low_cpu_mem_usage:
|
||||
@@ -1344,9 +1340,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
if parallel_config is not None:
|
||||
model.enable_parallelism(config=parallel_config)
|
||||
|
||||
if output_loading_info:
|
||||
return model, loading_info
|
||||
|
||||
@@ -1485,73 +1478,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
|
||||
)
|
||||
|
||||
def enable_parallelism(
|
||||
self,
|
||||
*,
|
||||
config: Union[ParallelConfig, ContextParallelConfig],
|
||||
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
|
||||
):
|
||||
from ..hooks.context_parallel import apply_context_parallel
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
|
||||
logger.warning(
|
||||
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
|
||||
)
|
||||
|
||||
if isinstance(config, ContextParallelConfig):
|
||||
config = ParallelConfig(context_parallel_config=config)
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
device_type = torch._C._get_accelerator().type
|
||||
device_module = torch.get_device_module(device_type)
|
||||
device = torch.device(device_type, rank % device_module.device_count())
|
||||
|
||||
cp_mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
|
||||
raise ValueError(
|
||||
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
|
||||
mesh_dim_names=("ring", "ulysses"),
|
||||
)
|
||||
|
||||
config.setup(rank, world_size, device, cp_mesh=cp_mesh)
|
||||
|
||||
if cp_plan is None and self._cp_plan is None:
|
||||
raise ValueError(
|
||||
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
|
||||
)
|
||||
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
apply_context_parallel(self, config.context_parallel_config, cp_plan)
|
||||
|
||||
self._parallel_config = config
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
processor = module.processor
|
||||
if processor is None or not hasattr(processor, "_parallel_config"):
|
||||
continue
|
||||
processor._parallel_config = config
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
|
||||
@@ -20,11 +20,20 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import is_torch_npu_available, is_torch_version
|
||||
from ..utils import is_kernels_available, is_torch_npu_available, is_torch_version
|
||||
from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
|
||||
from ..utils.kernels_utils import use_kernel_forward_from_hub
|
||||
from .activations import get_activation
|
||||
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
|
||||
|
||||
if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
from kernels import get_kernel
|
||||
|
||||
activation = get_kernel("kernels-community/activation", revision="add_more_act")
|
||||
silu_kernel = activation.layers.Silu
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
r"""
|
||||
Norm layer modified to incorporate timestep embeddings.
|
||||
@@ -57,7 +66,10 @@ class AdaLayerNorm(nn.Module):
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
self.silu = silu_kernel()
|
||||
else:
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
@@ -144,7 +156,10 @@ class AdaLayerNormZero(nn.Module):
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
self.silu = silu_kernel()
|
||||
else:
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
@@ -183,7 +198,10 @@ class AdaLayerNormZeroSingle(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
self.silu = silu_kernel()
|
||||
else:
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
@@ -335,7 +353,10 @@ class AdaLayerNormContinuous(nn.Module):
|
||||
norm_type="layer_norm",
|
||||
):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
self.silu = silu_kernel()
|
||||
else:
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
@@ -508,6 +529,7 @@ else:
|
||||
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class RMSNorm(nn.Module):
|
||||
r"""
|
||||
RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -92,7 +92,7 @@ class AuraFlowPatchEmbed(nn.Module):
|
||||
|
||||
return selected_indices
|
||||
|
||||
def forward(self, latent) -> torch.Tensor:
|
||||
def forward(self, latent):
|
||||
batch_size, num_channels, height, width = latent.size()
|
||||
latent = latent.view(
|
||||
batch_size,
|
||||
@@ -173,7 +173,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
residual = hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
@@ -242,7 +242,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
):
|
||||
residual = hidden_states
|
||||
residual_context = encoder_hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
@@ -472,7 +472,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
timestep: torch.LongTensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -122,7 +122,7 @@ class CogVideoXBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
@@ -441,7 +441,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
):
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -315,7 +315,7 @@ class ConsisIDBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
@@ -691,7 +691,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
id_cond: Optional[torch.Tensor] = None,
|
||||
id_vit_hidden: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
):
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -124,7 +124,7 @@ class LuminaNextDiTBlock(nn.Module):
|
||||
encoder_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
"""
|
||||
Perform a forward pass through the LuminaNextDiTBlock.
|
||||
|
||||
@@ -297,7 +297,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
|
||||
image_rotary_emb: torch.Tensor,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
return_dict=True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of LuminaNextDiT.
|
||||
|
||||
|
||||
@@ -120,7 +120,6 @@ def get_1d_rotary_pos_embed(
|
||||
|
||||
class BriaAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -162,12 +161,7 @@ class BriaAttnProcessor:
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -478,7 +472,7 @@ class BriaSingleTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
@@ -594,7 +588,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
return_dict: bool = True,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`BriaTransformer2DModel`] forward method.
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Tuple, Union
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -79,7 +79,7 @@ class CogView3PlusTransformerBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
@@ -293,7 +293,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`CogView3PlusTransformer2DModel`] forward method.
|
||||
|
||||
|
||||
@@ -494,7 +494,7 @@ class CogView4TransformerBlock(nn.Module):
|
||||
] = None,
|
||||
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
# 1. Timestep conditioning
|
||||
(
|
||||
norm_hidden_states,
|
||||
@@ -717,7 +717,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -22,9 +22,9 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -41,6 +41,12 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
from kernels import get_kernel
|
||||
|
||||
activation = get_kernel("kernels-community/activation", revision="add_more_act")
|
||||
gelu_tanh_kernel = activation.layers.GeluTanh
|
||||
|
||||
|
||||
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
@@ -74,7 +80,6 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st
|
||||
|
||||
class FluxAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -116,12 +121,7 @@ class FluxAttnProcessor:
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -143,7 +143,6 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
"""Flux Attention processor for IP-Adapter."""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
|
||||
@@ -228,7 +227,6 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -261,7 +259,6 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
|
||||
@@ -310,8 +307,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.added_proj_bias = added_proj_bias
|
||||
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
else:
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
@@ -322,8 +325,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if added_kv_proj_dim is not None:
|
||||
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
self.norm_added_q = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = RMSNorm(dim_head, eps=eps)
|
||||
else:
|
||||
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
@@ -361,6 +370,11 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
self.norm = AdaLayerNormZeroSingle(dim)
|
||||
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
# if not DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
# self.act_mlp = nn.GELU(approximate="tanh")
|
||||
# else:
|
||||
# self.act_mlp = gelu_tanh_kernel()
|
||||
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
self.attn = FluxAttention(
|
||||
@@ -566,15 +580,6 @@ class FluxTransformer2DModel(
|
||||
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -55,7 +55,7 @@ class HiDreamImageTimestepEmbed(nn.Module):
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
||||
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
|
||||
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
|
||||
t_emb = self.timestep_embedder(t_emb)
|
||||
return t_emb
|
||||
@@ -87,7 +87,7 @@ class HiDreamImagePatchEmbed(nn.Module):
|
||||
self.out_channels = out_channels
|
||||
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
|
||||
|
||||
def forward(self, latent) -> torch.Tensor:
|
||||
def forward(self, latent):
|
||||
latent = self.proj(latent)
|
||||
return latent
|
||||
|
||||
@@ -534,7 +534,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
wtype = hidden_states.dtype
|
||||
(
|
||||
shift_msa_i,
|
||||
@@ -592,7 +592,7 @@ class HiDreamBlock(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor:
|
||||
return self.block(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_masks=hidden_states_masks,
|
||||
@@ -786,7 +786,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
):
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
|
||||
@@ -529,7 +529,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
@@ -684,7 +684,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
token_replace_emb: torch.Tensor = None,
|
||||
num_tokens: int = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
@@ -1038,7 +1038,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
guidance: torch.Tensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -216,7 +216,7 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
indices_latents_history_4x: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
):
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -24,7 +24,6 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -52,7 +51,6 @@ class LTXVideoAttnProcessor:
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if is_torch_version("<", "2.0"):
|
||||
@@ -102,7 +100,6 @@ class LTXVideoAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -412,18 +409,6 @@ class LTXVideoTransformer3DModel(
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
_repeated_blocks = ["LTXVideoTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
"rope": {
|
||||
0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
|
||||
1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -25,7 +25,6 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
@@ -262,7 +261,6 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -336,7 +334,6 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
@@ -505,18 +502,6 @@ class QwenImageTransformer2DModel(
|
||||
_no_split_modules = ["QwenImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
"pos_embed": {
|
||||
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -73,7 +73,6 @@ def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states
|
||||
|
||||
class SkyReelsV2AttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -140,7 +139,6 @@ class SkyReelsV2AttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
@@ -153,7 +151,6 @@ class SkyReelsV2AttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
@@ -23,7 +23,6 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -67,7 +66,6 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
|
||||
|
||||
class WanAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -134,7 +132,6 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
@@ -147,7 +144,6 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
@@ -543,19 +539,6 @@ class WanTransformer3DModel(
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
_repeated_blocks = ["WanTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"rope": {
|
||||
0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
|
||||
1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
|
||||
},
|
||||
"blocks.0": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"blocks.*": {
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -82,7 +82,6 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
time_embedding_dim: Optional[int] = None,
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
@@ -101,23 +100,15 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
||||
if time_embed_dim % 2 != 0:
|
||||
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
)
|
||||
timestep_input_dim = time_embed_dim
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
||||
)
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
@@ -25,7 +25,6 @@ from ..utils import (
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.torch_utils import get_device
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -162,9 +161,7 @@ class AutoOffloadStrategy:
|
||||
|
||||
current_module_size = model.get_memory_footprint()
|
||||
|
||||
device_type = execution_device.type
|
||||
device_module = getattr(torch, device_type, torch.cuda)
|
||||
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
|
||||
mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
|
||||
mem_on_device = mem_on_device - self.memory_reserve_margin
|
||||
if current_module_size < mem_on_device:
|
||||
return []
|
||||
@@ -304,7 +301,7 @@ class ComponentsManager:
|
||||
cm.add("vae", vae_model, collection="sdxl")
|
||||
|
||||
# Enable auto offloading
|
||||
cm.enable_auto_cpu_offload()
|
||||
cm.enable_auto_cpu_offload(device="cuda")
|
||||
|
||||
# Retrieve components
|
||||
unet = cm.get_one(name="unet", collection="sdxl")
|
||||
@@ -493,8 +490,6 @@ class ComponentsManager:
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
if torch.xpu.is_available():
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
# YiYi TODO: rename to search_components for now, may remove this method
|
||||
def search_components(
|
||||
@@ -683,7 +678,7 @@ class ComponentsManager:
|
||||
|
||||
return get_return_dict(matches, return_dict_with_names)
|
||||
|
||||
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, memory_reserve_margin="3GB"):
|
||||
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
|
||||
"""
|
||||
Enable automatic CPU offloading for all components.
|
||||
|
||||
@@ -709,8 +704,6 @@ class ComponentsManager:
|
||||
|
||||
self.disable_auto_cpu_offload()
|
||||
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
|
||||
if device is None:
|
||||
device = get_device()
|
||||
device = torch.device(device)
|
||||
if device.index is None:
|
||||
device = torch.device(f"{device.type}:{0}")
|
||||
|
||||
@@ -323,7 +323,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
)
|
||||
if not has_remote_code and trust_remote_code:
|
||||
if not (has_remote_code and trust_remote_code):
|
||||
raise ValueError(
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
|
||||
@@ -285,7 +285,6 @@ else:
|
||||
]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
[
|
||||
"MarigoldDepthPipeline",
|
||||
@@ -394,7 +393,6 @@ else:
|
||||
"QwenImageImg2ImgPipeline",
|
||||
"QwenImageInpaintPipeline",
|
||||
"QwenImageEditPipeline",
|
||||
"QwenImageEditPlusPipeline",
|
||||
"QwenImageEditInpaintPipeline",
|
||||
"QwenImageControlNetInpaintPipeline",
|
||||
"QwenImageControlNetPipeline",
|
||||
@@ -684,7 +682,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
|
||||
from .lucy import LucyEditPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
from .marigold import (
|
||||
@@ -722,7 +719,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
QwenImageEditPipeline,
|
||||
QwenImageEditPlusPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImagePipeline,
|
||||
|
||||
@@ -688,11 +688,11 @@ class ChromaPipeline(
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||
|
||||
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
|
||||
@@ -749,12 +749,12 @@ class ChromaImg2ImgPipeline(
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||
|
||||
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||
strength (`float, *optional*, defaults to 0.9):
|
||||
Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
|
||||
be used as a starting point, adding more noise to it the larger the strength. The number of denoising
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_lucy_edit"] = ["LucyEditPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_lucy_edit import LucyEditPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,735 +0,0 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 The Decart AI Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modifications by Decart AI Team:
|
||||
# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.
|
||||
|
||||
import html
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...models import AutoencoderKLWan, WanTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import LucyPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> from typing import List
|
||||
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> from diffusers import AutoencoderKLWan, LucyEditPipeline
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> # Arguments
|
||||
>>> url = "https://d2drjpuinn46lb.cloudfront.net/painter_original_edit.mp4"
|
||||
>>> prompt = "Change the apron and blouse to a classic clown costume: satin polka-dot jumpsuit in bright primary colors, ruffled white collar, oversized pom-pom buttons, white gloves, oversized red shoes, red foam nose; soft window light from left, eye-level medium shot, natural folds and fabric highlights."
|
||||
>>> negative_prompt = ""
|
||||
>>> num_frames = 81
|
||||
>>> height = 480
|
||||
>>> width = 832
|
||||
|
||||
|
||||
>>> # Load video
|
||||
>>> def convert_video(video: List[Image.Image]) -> List[Image.Image]:
|
||||
... video = load_video(url)[:num_frames]
|
||||
... video = [video[i].resize((width, height)) for i in range(num_frames)]
|
||||
... return video
|
||||
|
||||
|
||||
>>> video = load_video(url, convert_method=convert_video)
|
||||
|
||||
>>> # Load model
|
||||
>>> model_id = "decart-ai/Lucy-Edit-Dev"
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = LucyEditPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Generate video
|
||||
>>> output = pipe(
|
||||
... prompt=prompt,
|
||||
... video=video,
|
||||
... negative_prompt=negative_prompt,
|
||||
... height=480,
|
||||
... width=832,
|
||||
... num_frames=81,
|
||||
... guidance_scale=5.0,
|
||||
... ).frames[0]
|
||||
|
||||
>>> # Export video
|
||||
>>> export_to_video(output, "output.mp4", fps=24)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LucyEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for video-to-video generation using Lucy Edit.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
transformer_2 ([`WanTransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
|
||||
two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
|
||||
stages. If not provided, only `transformer` is used.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
||||
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
||||
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer", "transformer_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer: Optional[WanTransformer3DModel] = None,
|
||||
transformer_2: Optional[WanTransformer3DModel] = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
expand_timesteps: bool = False, # Wan2.2 ti2v
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
transformer_2=transformer_2,
|
||||
)
|
||||
self.register_to_config(boundary_ratio=boundary_ratio)
|
||||
self.register_to_config(expand_timesteps=expand_timesteps)
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
video,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
||||
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
||||
|
||||
if video is None:
|
||||
raise ValueError("`video` is required, received None.")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
num_latent_frames = (
|
||||
(video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
|
||||
)
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_latent_frames,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
# Prepare noise latents
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# Prepare condition latents
|
||||
condition_latents = [
|
||||
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video
|
||||
]
|
||||
|
||||
condition_latents = torch.cat(condition_latents, dim=0).to(dtype)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
device, dtype
|
||||
)
|
||||
|
||||
condition_latents = (condition_latents - latents_mean) * latents_std
|
||||
|
||||
# Check shapes
|
||||
assert latents.shape == condition_latents.shape, (
|
||||
f"Latents shape {latents.shape} does not match expected shape {condition_latents.shape}. Please check the input."
|
||||
)
|
||||
|
||||
return latents, condition_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
video: List[Image.Image],
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
video (`List[Image.Image]`):
|
||||
The video to use as the condition for the video generation.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (`guidance_scale` < `1`).
|
||||
height (`int`, defaults to `480`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `832`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `81`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
guidance_scale_2 (`float`, *optional*, defaults to `None`):
|
||||
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
|
||||
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
|
||||
and the pipeline's `boundary_ratio` are not None.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`LucyPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
||||
truncated. If the prompt is shorter, it will be padded to this length.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~LucyPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`LucyPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
video,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
guidance_scale_2,
|
||||
)
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
logger.warning(
|
||||
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
|
||||
)
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
|
||||
guidance_scale_2 = guidance_scale
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_scale_2 = guidance_scale_2
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = (
|
||||
self.transformer.config.out_channels
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.out_channels
|
||||
)
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width).to(
|
||||
device, dtype=torch.float32
|
||||
)
|
||||
latents, condition_latents = self.prepare_latents(
|
||||
video,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
if self.config.boundary_ratio is not None:
|
||||
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
else:
|
||||
boundary_timestep = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
if boundary_timestep is None or t >= boundary_timestep:
|
||||
# wan2.1 or high-noise stage in wan2.2
|
||||
current_model = self.transformer
|
||||
current_guidance_scale = guidance_scale
|
||||
else:
|
||||
# low-noise stage in wan2.2
|
||||
current_model = self.transformer_2
|
||||
current_guidance_scale = guidance_scale_2
|
||||
|
||||
# latent_model_input = latents.to(transformer_dtype)
|
||||
latent_model_input = torch.cat([latents, condition_latents], dim=1).to(transformer_dtype)
|
||||
# latent_model_input = torch.cat([latents, latents], dim=1).to(transformer_dtype)
|
||||
if self.config.expand_timesteps:
|
||||
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
||||
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
|
||||
# batch_size, seq_len
|
||||
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
|
||||
else:
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
with current_model.cache_context("cond"):
|
||||
noise_pred = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with current_model.cache_context("uncond"):
|
||||
noise_uncond = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LucyPipelineOutput(frames=video)
|
||||
@@ -1,20 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class LucyPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Lucy pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
@@ -19,12 +19,12 @@ import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download
|
||||
from huggingface_hub.utils import HfHubHTTPError, OfflineModeIsEnabled, validate_hf_hub_args
|
||||
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
|
||||
from packaging import version
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from .. import __version__
|
||||
from ..utils import (
|
||||
@@ -48,12 +48,10 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import dispatch_model
|
||||
@@ -114,9 +112,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
@@ -195,9 +191,7 @@ def filter_model_files(filenames):
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
|
||||
|
||||
@@ -218,9 +212,7 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
@@ -1110,7 +1102,7 @@ def _download_dduf_file(
|
||||
if not local_files_only:
|
||||
try:
|
||||
info = model_info(pretrained_model_name, token=token, revision=revision)
|
||||
except (HfHubHTTPError, OfflineModeIsEnabled, requests.ConnectionError, httpx.NetworkError) as e:
|
||||
except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
|
||||
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
|
||||
local_files_only = True
|
||||
model_info_call_error = e # save error to reraise it if model is not cached locally
|
||||
|
||||
@@ -23,7 +23,6 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import requests
|
||||
@@ -37,8 +36,9 @@ from huggingface_hub import (
|
||||
read_dduf_file,
|
||||
snapshot_download,
|
||||
)
|
||||
from huggingface_hub.utils import HfHubHTTPError, OfflineModeIsEnabled, validate_hf_hub_args
|
||||
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
|
||||
from packaging import version
|
||||
from requests.exceptions import HTTPError
|
||||
from tqdm.auto import tqdm
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -1616,7 +1616,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
if not local_files_only:
|
||||
try:
|
||||
info = model_info(pretrained_model_name, token=token, revision=revision)
|
||||
except (HfHubHTTPError, OfflineModeIsEnabled, requests.ConnectionError, httpx.NetworkError) as e:
|
||||
except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
|
||||
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
|
||||
local_files_only = True
|
||||
model_info_call_error = e # save error to reraise it if model is not cached locally
|
||||
|
||||
@@ -28,7 +28,6 @@ else:
|
||||
_import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
|
||||
_import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
|
||||
_import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
|
||||
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
|
||||
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
|
||||
|
||||
@@ -44,7 +43,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_qwenimage_controlnet_inpaint import QwenImageControlNetInpaintPipeline
|
||||
from .pipeline_qwenimage_edit import QwenImageEditPipeline
|
||||
from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline
|
||||
from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
|
||||
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
|
||||
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
|
||||
else:
|
||||
|
||||
@@ -208,6 +208,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.vl_processor = processor
|
||||
self.tokenizer_max_length = 1024
|
||||
|
||||
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
@@ -1,883 +0,0 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import QwenImageLoraLoaderMixin
|
||||
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import QwenImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import QwenImageEditPlusPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
|
||||
... ).convert("RGB")
|
||||
>>> prompt = (
|
||||
... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
|
||||
... )
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(image, prompt, num_inference_steps=50).images[0]
|
||||
>>> image.save("qwenimage_edit_plus.png")
|
||||
```
|
||||
"""
|
||||
|
||||
CONDITION_IMAGE_SIZE = 384 * 384
|
||||
VAE_IMAGE_SIZE = 1024 * 1024
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def calculate_dimensions(target_area, ratio):
|
||||
width = math.sqrt(target_area * ratio)
|
||||
height = width / ratio
|
||||
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
|
||||
return width, height
|
||||
|
||||
|
||||
class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
r"""
|
||||
The Qwen-Image-Edit pipeline for image editing.
|
||||
|
||||
Args:
|
||||
transformer ([`QwenImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
|
||||
tokenizer (`QwenTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLQwenImage,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
processor: Qwen2VLProcessor,
|
||||
transformer: QwenImageTransformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
|
||||
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.tokenizer_max_length = 1024
|
||||
|
||||
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.prompt_template_encode_start_idx = 64
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
|
||||
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
|
||||
return split_result
|
||||
|
||||
def _get_qwen_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
|
||||
if isinstance(image, list):
|
||||
base_img_prompt = ""
|
||||
for i, img in enumerate(image):
|
||||
base_img_prompt += img_prompt_template.format(i + 1)
|
||||
elif image is not None:
|
||||
base_img_prompt = img_prompt_template.format(1)
|
||||
else:
|
||||
base_img_prompt = ""
|
||||
|
||||
template = self.prompt_template_encode
|
||||
|
||||
drop_idx = self.prompt_template_encode_start_idx
|
||||
txt = [template.format(base_img_prompt + e) for e in prompt]
|
||||
|
||||
model_inputs = self.processor(
|
||||
text=txt,
|
||||
images=image,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
|
||||
outputs = self.text_encoder(
|
||||
input_ids=model_inputs.input_ids,
|
||||
attention_mask=model_inputs.attention_mask,
|
||||
pixel_values=model_inputs.pixel_values,
|
||||
image_grid_thw=model_inputs.image_grid_thw,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
hidden_states = outputs.hidden_states[-1]
|
||||
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
||||
)
|
||||
encoder_attention_mask = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 1024,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
image (`torch.Tensor`, *optional*):
|
||||
image to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.latent_channels, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, self.latent_channels, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
image_latents = (image_latents - latents_mean) / latents_std
|
||||
|
||||
return image_latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
images,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, 1, num_channels_latents, height, width)
|
||||
|
||||
image_latents = None
|
||||
if images is not None:
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
all_image_latents = []
|
||||
for image in images:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != self.latent_channels:
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
image_latents = torch.cat([image_latents], dim=0)
|
||||
|
||||
image_latent_height, image_latent_width = image_latents.shape[3:]
|
||||
image_latents = self._pack_latents(
|
||||
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
|
||||
)
|
||||
all_image_latents.append(image_latents)
|
||||
image_latents = torch.cat(all_image_latents, dim=1)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
return latents, image_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: Optional[PipelineImageInput] = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
true_cfg_scale: float = 4.0,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
||||
not greater than `1`).
|
||||
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
||||
true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
|
||||
Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
|
||||
equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
|
||||
enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
|
||||
encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
|
||||
lower image quality.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to None):
|
||||
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
|
||||
where the guidance scale is applied during inference through noise prediction rescaling, guidance
|
||||
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
|
||||
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
|
||||
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
|
||||
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
|
||||
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
|
||||
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
|
||||
enable classifier-free guidance computations).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
image_size = image[-1].size if isinstance(image, list) else image.size
|
||||
calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
|
||||
height = height or calculated_height
|
||||
width = width or calculated_width
|
||||
|
||||
multiple_of = self.vae_scale_factor * 2
|
||||
width = width // multiple_of * multiple_of
|
||||
height = height // multiple_of * multiple_of
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# 3. Preprocess image
|
||||
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
||||
if not isinstance(image, list):
|
||||
image = [image]
|
||||
condition_image_sizes = []
|
||||
condition_images = []
|
||||
vae_image_sizes = []
|
||||
vae_images = []
|
||||
for img in image:
|
||||
image_width, image_height = img.size
|
||||
condition_width, condition_height = calculate_dimensions(
|
||||
CONDITION_IMAGE_SIZE, image_width / image_height
|
||||
)
|
||||
vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
|
||||
condition_image_sizes.append((condition_width, condition_height))
|
||||
vae_image_sizes.append((vae_width, vae_height))
|
||||
condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
|
||||
vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
|
||||
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
||||
)
|
||||
|
||||
if true_cfg_scale > 1 and not has_neg_prompt:
|
||||
logger.warning(
|
||||
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
|
||||
)
|
||||
elif true_cfg_scale <= 1 and has_neg_prompt:
|
||||
logger.warning(
|
||||
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
|
||||
)
|
||||
|
||||
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
||||
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
||||
image=condition_images,
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_true_cfg:
|
||||
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
||||
image=condition_images,
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, image_latents = self.prepare_latents(
|
||||
vae_images,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
img_shapes = [
|
||||
[
|
||||
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
|
||||
*[
|
||||
(1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
|
||||
for vae_width, vae_height in vae_image_sizes
|
||||
],
|
||||
]
|
||||
] * batch_size
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance
|
||||
if self.transformer.config.guidance_embeds and guidance_scale is None:
|
||||
raise ValueError("guidance_scale is required for guidance-distilled model.")
|
||||
elif self.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
|
||||
logger.warning(
|
||||
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
|
||||
)
|
||||
guidance = None
|
||||
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
|
||||
guidance = None
|
||||
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
latent_model_input = latents
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred[:, : latents.size(1)]
|
||||
|
||||
if do_true_cfg:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
|
||||
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
||||
noise_pred = comb_pred * (cond_norm / noise_norm)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return QwenImagePipelineOutput(images=image)
|
||||
@@ -152,26 +152,16 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanVACETransformer3DModel`]):
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
|
||||
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
|
||||
`transformer` is used.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
||||
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
||||
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -180,8 +170,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
transformer: WanVACETransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer_2: WanVACETransformer3DModel = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -190,10 +178,9 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
transformer_2=transformer_2,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.register_to_config(boundary_ratio=boundary_ratio)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
@@ -334,7 +321,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
video=None,
|
||||
mask=None,
|
||||
reference_images=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
if height % base != 0 or width % base != 0:
|
||||
@@ -346,8 +332,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
||||
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -683,7 +667,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
@@ -745,10 +728,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
guidance_scale_2 (`float`, *optional*, defaults to `None`):
|
||||
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
|
||||
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
|
||||
and the pipeline's `boundary_ratio` are not None.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -795,7 +774,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# Simplification of implementation for now
|
||||
if prompt is not None and not isinstance(prompt, str):
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError("Passing a list of prompts is not yet supported. This may be supported in the future.")
|
||||
if num_videos_per_prompt != 1:
|
||||
raise ValueError(
|
||||
@@ -814,7 +793,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
video,
|
||||
mask,
|
||||
reference_images,
|
||||
guidance_scale_2,
|
||||
)
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
@@ -824,11 +802,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
|
||||
guidance_scale_2 = guidance_scale
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_scale_2 = guidance_scale_2
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
@@ -922,53 +896,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
if self.config.boundary_ratio is not None:
|
||||
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
else:
|
||||
boundary_timestep = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
if boundary_timestep is None or t >= boundary_timestep:
|
||||
# wan2.1 or high-noise stage in wan2.2
|
||||
current_model = self.transformer
|
||||
current_guidance_scale = guidance_scale
|
||||
else:
|
||||
# low-noise stage in wan2.2
|
||||
current_model = self.transformer_2
|
||||
current_guidance_scale = guidance_scale_2
|
||||
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
with current_model.cache_context("cond"):
|
||||
noise_pred = current_model(
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with current_model.cache_context("uncond"):
|
||||
noise_uncond = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -648,21 +648,6 @@ class ConsistencyDecoderVAE(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ContextParallelConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1068,21 +1053,6 @@ class OmniGenTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ParallelConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PixArtTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1592,21 +1592,6 @@ class LTXPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LucyEditPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Lumina2Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1892,21 +1877,6 @@ class QwenImageEditPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class QwenImageEditPlusPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class QwenImageImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -247,7 +247,6 @@ def find_pipeline_class(loaded_module):
|
||||
def get_cached_module_file(
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
module_file: str,
|
||||
subfolder: Optional[str] = None,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
@@ -354,7 +353,6 @@ def get_cached_module_file(
|
||||
resolved_module_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
module_file,
|
||||
subfolder=subfolder,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
@@ -412,7 +410,6 @@ def get_cached_module_file(
|
||||
get_cached_module_file(
|
||||
pretrained_model_name_or_path,
|
||||
f"{module_needed}.py",
|
||||
subfolder=subfolder,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
@@ -427,7 +424,6 @@ def get_cached_module_file(
|
||||
def get_class_from_dynamic_module(
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
module_file: str,
|
||||
subfolder: Optional[str] = None,
|
||||
class_name: Optional[str] = None,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
@@ -501,7 +497,6 @@ def get_class_from_dynamic_module(
|
||||
final_module = get_cached_module_file(
|
||||
pretrained_model_name_or_path,
|
||||
module_file,
|
||||
subfolder=subfolder,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
|
||||
@@ -38,13 +38,13 @@ from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
|
||||
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
is_jinja_available,
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
from packaging import version
|
||||
from requests import HTTPError
|
||||
|
||||
from .. import __version__
|
||||
from .constants import (
|
||||
@@ -316,7 +316,7 @@ def _get_model_file(
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
|
||||
) from e
|
||||
except HfHubHTTPError as e:
|
||||
except HTTPError as e:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{e}"
|
||||
) from e
|
||||
@@ -432,7 +432,7 @@ def _get_checkpoint_shard_files(
|
||||
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
|
||||
except HfHubHTTPError as e:
|
||||
except HTTPError as e:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
|
||||
" again after checking your internet connection."
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Union
|
||||
|
||||
from ..utils import get_logger
|
||||
from .import_utils import is_kernels_available
|
||||
|
||||
@@ -21,3 +23,42 @@ def _get_fa3_from_hub():
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if is_kernels_available():
|
||||
from kernels import (
|
||||
Device,
|
||||
LayerRepository,
|
||||
register_kernel_mapping,
|
||||
replace_kernel_forward_from_hub,
|
||||
use_kernel_forward_from_hub,
|
||||
)
|
||||
|
||||
_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
|
||||
"RMSNorm": {
|
||||
"cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
|
||||
},
|
||||
}
|
||||
|
||||
register_kernel_mapping(_KERNEL_MAPPING)
|
||||
|
||||
else:
|
||||
# Stub to make decorators int transformers work when `kernels`
|
||||
# is not installed.
|
||||
def use_kernel_forward_from_hub(*args, **kwargs):
|
||||
def decorator(cls):
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
class LayerRepository:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
|
||||
|
||||
def replace_kernel_forward_from_hub(*args, **kwargs):
|
||||
raise RuntimeError(
|
||||
"replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
|
||||
)
|
||||
|
||||
def register_kernel_mapping(*args, **kwargs):
|
||||
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
|
||||
|
||||
@@ -43,6 +43,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = AuraFlowPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
|
||||
@@ -21,6 +21,7 @@ from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLCogVideoX,
|
||||
CogVideoXDDIMScheduler,
|
||||
CogVideoXDPMScheduler,
|
||||
CogVideoXPipeline,
|
||||
CogVideoXTransformer3DModel,
|
||||
@@ -43,6 +44,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = CogVideoXPipeline
|
||||
scheduler_cls = CogVideoXDPMScheduler
|
||||
scheduler_kwargs = {"timestep_spacing": "trailing"}
|
||||
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
|
||||
|
||||
transformer_kwargs = {
|
||||
"num_attention_heads": 4,
|
||||
|
||||
@@ -50,6 +50,7 @@ class TokenizerWrapper:
|
||||
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = CogView4Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
@@ -123,29 +124,30 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
|
||||
"""
|
||||
components, _, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, _, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
|
||||
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
|
||||
pipe_from_pretrained.to(torch_device)
|
||||
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
|
||||
pipe_from_pretrained.to(torch_device)
|
||||
|
||||
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
|
||||
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
)
|
||||
|
||||
@parameterized.expand([("block_level", True), ("leaf_level", False)])
|
||||
@require_torch_accelerator
|
||||
|
||||
@@ -55,8 +55,9 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
|
||||
@require_peft_backend
|
||||
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = FluxPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler()
|
||||
scheduler_kwargs = {}
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
transformer_kwargs = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
@@ -281,8 +282,9 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = FluxControlPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler()
|
||||
scheduler_kwargs = {}
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
transformer_kwargs = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 8,
|
||||
@@ -905,13 +907,6 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_kohya_embedders_conversion(self):
|
||||
"""Test that embedders load without throwing errors"""
|
||||
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
|
||||
self.pipeline.unload_lora_weights()
|
||||
|
||||
assert True
|
||||
|
||||
def test_flux_xlabs(self):
|
||||
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
|
||||
@@ -51,6 +51,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = HunyuanVideoPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
@@ -253,7 +254,6 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("cuda", 7): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
|
||||
("xpu", 3): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@@ -37,6 +37,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = LTXPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
|
||||
@@ -39,6 +39,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
|
||||
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = Lumina2Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
@@ -140,30 +141,33 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
strict=False,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
out = pipe(**inputs)[0]
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
out = pipe(**inputs)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
@@ -37,6 +37,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = MochiPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
|
||||
@@ -37,6 +37,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = QwenImagePipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
|
||||
@@ -31,8 +31,9 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
@require_peft_backend
|
||||
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = SanaPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {"shift": 7.0}
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
|
||||
scheduler_kwargs = {}
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
transformer_kwargs = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
|
||||
@@ -55,6 +55,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = StableDiffusion3Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
transformer_kwargs = {
|
||||
"sample_size": 32,
|
||||
"patch_size": 1,
|
||||
|
||||
@@ -42,6 +42,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = WanPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
|
||||
@@ -50,6 +50,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = WanVACEPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
@@ -164,8 +165,9 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_exclude_modules_wanvace(self):
|
||||
scheduler_cls = self.scheduler_classes[0]
|
||||
exclude_module_name = "vace_blocks.0.proj_out"
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
|
||||
+1160
-1081
File diff suppressed because it is too large
Load Diff
@@ -37,8 +37,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
|
||||
from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
|
||||
from huggingface_hub.utils import HfHubHTTPError, is_jinja_available
|
||||
from huggingface_hub.utils import is_jinja_available
|
||||
from parameterized import parameterized
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
@@ -271,7 +272,7 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HfHubHTTPError("Server down", response=mock.Mock())
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
@@ -295,7 +296,7 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
error_response = mock.Mock(
|
||||
status_code=500,
|
||||
headers={},
|
||||
raise_for_status=mock.Mock(side_effect=HfHubHTTPError("Server down", response=mock.Mock())),
|
||||
raise_for_status=mock.Mock(side_effect=HTTPError),
|
||||
json=mock.Mock(return_value={}),
|
||||
)
|
||||
|
||||
|
||||
@@ -48,7 +48,6 @@ class EasyAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
test_xformers_attention = False
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
|
||||
@@ -47,8 +47,8 @@ class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
supports_dduf = False
|
||||
|
||||
|
||||
@@ -18,13 +18,11 @@ import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import XLMRobertaTokenizerFast
|
||||
|
||||
from diffusers import DDIMScheduler, KandinskyPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
|
||||
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
@@ -217,11 +215,6 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
dummy = Dummies()
|
||||
return dummy.get_dummy_inputs(device=device, seed=seed)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.56.2"),
|
||||
reason="Latest transformers changes the slices",
|
||||
strict=False,
|
||||
)
|
||||
def test_kandinsky(self):
|
||||
device = "cpu"
|
||||
|
||||
|
||||
@@ -16,10 +16,8 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from diffusers import KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyInpaintCombinedPipeline
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
@@ -75,11 +73,6 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
)
|
||||
return inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.56.2"),
|
||||
reason="Latest transformers changes the slices",
|
||||
strict=False,
|
||||
)
|
||||
def test_kandinsky(self):
|
||||
device = "cpu"
|
||||
|
||||
@@ -188,11 +181,6 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
|
||||
inputs.pop("negative_image_embeds")
|
||||
return inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.56.2"),
|
||||
reason="Latest transformers changes the slices",
|
||||
strict=False,
|
||||
)
|
||||
def test_kandinsky(self):
|
||||
device = "cpu"
|
||||
|
||||
@@ -304,11 +292,6 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
|
||||
inputs.pop("negative_image_embeds")
|
||||
return inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.56.2"),
|
||||
reason="Latest transformers changes the slices",
|
||||
strict=False,
|
||||
)
|
||||
def test_kandinsky(self):
|
||||
device = "cpu"
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import XLMRobertaTokenizerFast
|
||||
@@ -32,7 +31,6 @@ from diffusers import (
|
||||
VQModel,
|
||||
)
|
||||
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
@@ -239,11 +237,6 @@ class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
dummies = Dummies()
|
||||
return dummies.get_dummy_inputs(device=device, seed=seed)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.56.2"),
|
||||
reason="Latest transformers changes the slices",
|
||||
strict=False,
|
||||
)
|
||||
def test_kandinsky_img2img(self):
|
||||
device = "cpu"
|
||||
|
||||
|
||||
@@ -18,14 +18,12 @@ import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import XLMRobertaTokenizerFast
|
||||
|
||||
from diffusers import DDIMScheduler, KandinskyInpaintPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
|
||||
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
@@ -233,11 +231,6 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
dummies = Dummies()
|
||||
return dummies.get_dummy_inputs(device=device, seed=seed)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.56.2"),
|
||||
reason="Latest transformers changes the slices",
|
||||
strict=False,
|
||||
)
|
||||
def test_kandinsky_inpaint(self):
|
||||
device = "cpu"
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user