Compare commits

..

6 Commits

Author SHA1 Message Date
sayakpaul 92199ff3ac up 2025-09-22 16:46:49 +05:30
sayakpaul 04e9323055 up 2025-09-18 17:23:04 +05:30
sayakpaul 9a09162baf up 2025-09-18 14:59:00 +05:30
sayakpaul 33a8a3be0c up 2025-09-18 14:49:48 +05:30
sayakpaul 58743c3ee7 kernelize gelu. 2025-09-16 18:09:12 +05:30
sayakpaul 50c0b786d2 start kernelize. 2025-09-15 16:26:52 +05:30
105 changed files with 2591 additions and 5314 deletions
+5 -3
View File
@@ -23,7 +23,11 @@
- local: using-diffusers/reusing_seeds
title: Reproducibility
- local: using-diffusers/schedulers
title: Schedulers
title: Load schedulers and models
- local: using-diffusers/models
title: Models
- local: using-diffusers/scheduler_features
title: Scheduler features
- local: using-diffusers/other-formats
title: Model files and layouts
- local: using-diffusers/push_to_hub
@@ -64,8 +68,6 @@
title: Accelerate inference
- local: optimization/cache
title: Caching
- local: optimization/attention_backends
title: Attention backends
- local: optimization/memory
title: Reduce memory usage
- local: optimization/speed-memory-optims
+1 -33
View File
@@ -26,7 +26,6 @@ Qwen-Image comes in the following variants:
|:----------:|:--------:|
| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |
| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
<Tip>
@@ -97,29 +96,6 @@ The `guidance_scale` parameter in the pipeline is there to support future guidan
</Tip>
## Multi-image reference with QwenImageEditPlusPipeline
With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
```
import torch
from PIL import Image
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image
pipe = QwenImageEditPlusPipeline.from_pretrained(
"Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
image=[image_1, image_2],
prompt="put the penguin and the cat at a game show called "Qwen Edit Plus Games"",
num_inference_steps=50
).images[0]
```
## QwenImagePipeline
[[autodoc]] QwenImagePipeline
@@ -150,15 +126,7 @@ image = pipe(
- all
- __call__
## QwenImageControlNetPipeline
[[autodoc]] QwenImageControlNetPipeline
- all
- __call__
## QwenImageEditPlusPipeline
[[autodoc]] QwenImageEditPlusPipeline
## QwenImaggeControlNetPipeline
- all
- __call__
@@ -1,106 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# Attention backends
> [!TIP]
> The attention dispatcher is an experimental feature. Please open an issue if you have any feedback or encounter any problems.
Diffusers provides several optimized attention algorithms that are more memory and computationally efficient through it's *attention dispatcher*. The dispatcher acts as a router for managing and switching between different attention implementations and provides a unified interface for interacting with them.
Refer to the table below for an overview of the available attention families and to the [Available backends](#available-backends) section for a more complete list.
| attention family | main feature |
|---|---|
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
| SageAttention | quantizes attention to int8 |
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
| xFormers | memory-efficient attention with support for various attention kernels |
This guide will show you how to set and use the different attention backends.
## set_attention_backend
The [`~ModelMixin.set_attention_backend`] method iterates through all the modules in the model and sets the appropriate attention backend to use. The attention backend setting persists until [`~ModelMixin.reset_attention_backend`] is called.
The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [kernel](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
> [!TIP]
> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend("flash")`.
```py
import torch
from diffusers import QwenImagePipeline
pipeline = QwenImagePipeline.from_pretrained(
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
)
pipeline.transformer.set_attention_backend("_flash_3_hub")
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
pipeline(prompt).images[0]
```
To restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].
```py
pipeline.transformer.reset_attention_backend()
```
## attention_backend context manager
The [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager temporarily sets an attention backend for a model within the context. Outside the context, the default attention (PyTorch's native scaled dot product attention) is used. This is useful if you want to use different backends for different parts of a pipeline or if you want to test the different backends.
```py
import torch
from diffusers import QwenImagePipeline
pipeline = QwenImagePipeline.from_pretrained(
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
)
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
with attention_backend("_flash_3_hub"):
image = pipeline(prompt).images[0]
```
## Available backends
Refer to the table below for a complete list of available attention backends and their variants.
| Backend Name | Family | Description |
|--------------|--------|-------------|
| `native` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Default backend using PyTorch's scaled_dot_product_attention |
| `flex` | [FlexAttention](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention) | PyTorch FlexAttention implementation |
| `_native_cudnn` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | CuDNN-optimized attention |
| `_native_efficient` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Memory-efficient attention |
| `_native_flash` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | PyTorch's FlashAttention |
| `_native_math` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Math-based attention (fallback) |
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
| `_sage_qk_int8_pv_fp16_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (CUDA) |
| `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) |
| `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention |
+120
View File
@@ -0,0 +1,120 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
[[open-in-colab]]
# Models
A diffusion model relies on a few individual models working together to generate an output. These models are responsible for denoising, encoding inputs, and decoding latents into the actual outputs.
This guide will show you how to load models.
## Loading a model
All models are loaded with the [`~ModelMixin.from_pretrained`] method, which downloads and caches the latest model version. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache.
Pass the `subfolder` argument to [`~ModelMixin.from_pretrained`] to specify where to load the model weights from. Omit the `subfolder` argument if the repository doesn't have a subfolder structure or if you're loading a standalone model.
```py
from diffusers import QwenImageTransformer2DModel
model = QwenImageTransformer2DModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer")
```
## AutoModel
[`AutoModel`] detects the model class from a `model_index.json` file or a model's `config.json` file. It fetches the correct model class from these files and delegates the actual loading to the model class. [`AutoModel`] is useful for automatic model type detection without needing to know the exact model class beforehand.
```py
from diffusers import AutoModel
model = AutoModel.from_pretrained(
"Qwen/Qwen-Image", subfolder="transformer"
)
```
## Model data types
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to load a model with a specific data type. This allows you to load a model in a lower precision to reduce memory usage.
```py
import torch
from diffusers import QwenImageTransformer2DModel
model = QwenImageTransformer2DModel.from_pretrained(
"Qwen/Qwen-Image",
subfolder="transformer",
torch_dtype=torch.bfloat16
)
```
[nn.Module.to](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to) can also convert to a specific data type on the fly. However, it converts *all* weights to the requested data type unlike `torch_dtype` which respects `_keep_in_fp32_modules`. This argument preserves layers in `torch.float32` for numerical stability and best generation quality (see example [_keep_in_fp32_modules](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374))
```py
from diffusers import QwenImageTransformer2DModel
model = QwenImageTransformer2DModel.from_pretrained(
"Qwen/Qwen-Image", subfolder="transformer"
)
model = model.to(dtype=torch.float16)
```
## Device placement
Use the `device_map` argument in [`~ModelMixin.from_pretrained`] to place a model on an accelerator like a GPU. It is especially helpful where there are multiple GPUs.
Diffusers currently provides three options to `device_map` for individual models, `"cuda"`, `"balanced"` and `"auto"`. Refer to the table below to compare the three placement strategies.
| parameter | description |
|---|---|
| `"cuda"` | places pipeline on a supported accelerator (CUDA) |
| `"balanced"` | evenly distributes pipeline on all GPUs |
| `"auto"` | distribute model from fastest device first to slowest |
Use the `max_memory` argument in [`~ModelMixin.from_pretrained`] to allocate a maximum amount of memory to use on each device. By default, Diffusers uses the maximum amount available.
```py
import torch
from diffusers import QwenImagePipeline
max_memory = {0: "16GB", 1: "16GB"}
pipeline = QwenImagePipeline.from_pretrained(
"Qwen/Qwen-Image",
torch_dtype=torch.bfloat16,
device_map="cuda",
max_memory=max_memory
)
```
The `hf_device_map` attribute allows you to access and view the `device_map`.
```py
print(transformer.hf_device_map)
# {'': device(type='cuda')}
```
## Saving models
Save a model with the [`~ModelMixin.save_pretrained`] method.
```py
from diffusers import QwenImageTransformer2DModel
model = QwenImageTransformer2DModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer")
model.save_pretrained("./local/model")
```
For large models, it is helpful to use `max_shard_size` to save a model as multiple shards. A shard can be loaded faster and save memory (refer to the [parallel loading](./loading#parallel-loading) docs for more details), especially if there is more than one GPU.
```py
model.save_pretrained("./local/model", max_shard_size="5GB")
```
@@ -0,0 +1,235 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Scheduler features
The scheduler is an important component of any diffusion model because it controls the entire denoising (or sampling) process. There are many types of schedulers, some are optimized for speed and some for quality. With Diffusers, you can modify the scheduler configuration to use custom noise schedules, sigmas, and rescale the noise schedule. Changing these parameters can have profound effects on inference quality and speed.
This guide will demonstrate how to use these features to improve inference quality.
> [!TIP]
> Diffusers currently only supports the `timesteps` and `sigmas` parameters for a select list of schedulers and pipelines. Feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
## Timestep schedules
The timestep or noise schedule determines the amount of noise at each sampling step. The scheduler uses this to generate an image with the corresponding amount of noise at each step. The timestep schedule is generated from the scheduler's default configuration, but you can customize the scheduler to use new and optimized sampling schedules that aren't in Diffusers yet.
For example, [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) is a method for optimizing a sampling schedule to generate a high-quality image in as little as 10 steps. The optimal [10-step schedule](https://github.com/huggingface/diffusers/blob/a7bf77fc284810483f1e60afe34d1d27ad91ce2e/src/diffusers/schedulers/scheduling_utils.py#L51) for Stable Diffusion XL is:
```py
from diffusers.schedulers import AysSchedules
sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
print(sampling_schedule)
"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
```
You can use the AYS sampling schedule in a pipeline by passing it to the `timesteps` parameter.
```py
pipeline = StableDiffusionXLPipeline.from_pretrained(
"SG161222/RealVisXL_V4.0",
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++")
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
generator = torch.Generator(device="cpu").manual_seed(2487854446)
image = pipeline(
prompt=prompt,
negative_prompt="",
generator=generator,
timesteps=sampling_schedule,
).images[0]
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption>
</div>
</div>
## Timestep spacing
The way sample steps are selected in the schedule can affect the quality of the generated image, especially with respect to [rescaling the noise schedule](#rescale-noise-schedule), which can enable a model to generate much brighter or darker images. Diffusers provides three timestep spacing methods:
- `leading` creates evenly spaced steps
- `linspace` includes the first and last steps and evenly selects the remaining intermediate steps
- `trailing` only includes the last step and evenly selects the remaining intermediate steps starting from the end
It is recommended to use the `trailing` spacing method because it generates higher quality images with more details when there are fewer sample steps. But the difference in quality is not as obvious for more standard sample step values.
```py
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
pipeline = StableDiffusionXLPipeline.from_pretrained(
"SG161222/RealVisXL_V4.0",
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
generator = torch.Generator(device="cpu").manual_seed(2487854446)
image = pipeline(
prompt=prompt,
negative_prompt="",
generator=generator,
num_inference_steps=5,
).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/trailing_spacing.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">trailing spacing after 5 steps</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/leading_spacing.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">leading spacing after 5 steps</figcaption>
</div>
</div>
## Sigmas
The `sigmas` parameter is the amount of noise added at each timestep according to the timestep schedule. Like the `timesteps` parameter, you can customize the `sigmas` parameter to control how much noise is added at each step. When you use a custom `sigmas` value, the `timesteps` are calculated from the custom `sigmas` value and the default scheduler configuration is ignored.
For example, you can manually pass the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) for something like the 10-step AYS schedule from before to the pipeline.
```py
import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
prompt = "anthropomorphic capybara wearing a suit and working with a computer"
generator = torch.Generator(device='cuda').manual_seed(123)
image = pipeline(
prompt=prompt,
num_inference_steps=10,
sigmas=sigmas,
generator=generator
).images[0]
```
When you take a look at the scheduler's `timesteps` parameter, you'll see that it is the same as the AYS timestep schedule because the `timestep` schedule is calculated from the `sigmas`.
```py
print(f" timesteps: {pipe.scheduler.timesteps}")
"timesteps: tensor([999., 845., 730., 587., 443., 310., 193., 116., 53., 13.], device='cuda:0')"
```
### Karras sigmas
> [!TIP]
> Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas.
>
> Karras sigmas should not be used for models that weren't trained with them. For example, the base Stable Diffusion XL model shouldn't use Karras sigmas but the [DreamShaperXL](https://hf.co/Lykon/dreamshaper-xl-1-0) model can since they are trained with Karras sigmas.
Karras scheduler's use the timestep schedule and sigmas from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://hf.co/papers/2206.00364) paper. This scheduler variant applies a smaller amount of noise per step as it approaches the end of the sampling process compared to other schedulers, and can increase the level of details in the generated image.
Enable Karras sigmas by setting `use_karras_sigmas=True` in the scheduler.
```py
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
pipeline = StableDiffusionXLPipeline.from_pretrained(
"SG161222/RealVisXL_V4.0",
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
generator = torch.Generator(device="cpu").manual_seed(2487854446)
image = pipeline(
prompt=prompt,
negative_prompt="",
generator=generator,
).images[0]
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_true.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas enabled</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_false.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas disabled</figcaption>
</div>
</div>
## Rescale noise schedule
In the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://hf.co/papers/2305.08891) paper, the authors discovered that common noise schedules allowed some signal to leak into the last timestep. This signal leakage at inference can cause models to only generate images with medium brightness. By enforcing a zero signal-to-noise ratio (SNR) for the timstep schedule and sampling from the last timestep, the model can be improved to generate very bright or dark images.
> [!TIP]
> For inference, you need a model that has been trained with *v_prediction*. To train your own model with *v_prediction*, add the following flag to the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) scripts.
>
> ```bash
> --prediction_type="v_prediction"
> ```
For example, load the [ptx0/pseudo-journey-v2](https://hf.co/ptx0/pseudo-journey-v2) checkpoint which was trained with `v_prediction` and the [`DDIMScheduler`]. Configure the following parameters in the [`DDIMScheduler`]:
* `rescale_betas_zero_snr=True` to rescale the noise schedule to zero SNR
* `timestep_spacing="trailing"` to start sampling from the last timestep
Set `guidance_rescale` in the pipeline to prevent over-exposure. A lower value increases brightness but some of the details may appear washed out.
```py
from diffusers import DiffusionPipeline, DDIMScheduler
pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", use_safetensors=True)
pipeline.scheduler = DDIMScheduler.from_config(
pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipeline.to("cuda")
prompt = "cinematic photo of a snowy mountain at night with the northern lights aurora borealis overhead, 35mm photograph, film, professional, 4k, highly detailed"
generator = torch.Generator(device="cpu").manual_seed(23)
image = pipeline(prompt, guidance_rescale=0.7, generator=generator).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/no-zero-snr.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">default Stable Diffusion v2-1 image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/zero-snr.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">image with zero SNR and trailing timestep spacing enabled</figcaption>
</div>
</div>
+166 -239
View File
@@ -10,273 +10,200 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Load schedulers and models
[[open-in-colab]]
# Schedulers
Diffusion pipelines are a collection of interchangeable schedulers and models that can be mixed and matched to tailor a pipeline to a specific use case. The scheduler encapsulates the entire denoising process such as the number of denoising steps and the algorithm for finding the denoised sample. A scheduler is not parameterized or trained so they don't take very much memory. The model is usually only concerned with the forward pass of going from a noisy input to a less noisy sample.
A scheduler is an algorithm that provides instructions to the denoising process such as how much noise to remove at a certain step. It takes the model prediction from step *t* and applies an update for how to compute the next sample at step *t-1*. Different schedulers produce different results; some are faster while others are more accurate.
Diffusers supports many schedulers and allows you to modify their timestep schedules, timestep spacing, and more, to generate high-quality images in fewer steps.
This guide will show you how to load and customize schedulers.
## Loading schedulers
Schedulers don't have any parameters and are defined in a configuration file. Access the `.scheduler` attribute of a pipeline to view the configuration.
This guide will show you how to load schedulers and models to customize a pipeline. You'll use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint throughout this guide, so let's load it first.
```py
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, device_map="cuda"
)
pipeline.scheduler
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
```
Load a different scheduler with [`~SchedulerMixin.from_pretrained`] and specify the `subfolder` argument to load the configuration file into the correct subfolder of the pipeline repository. Pass the new scheduler to the existing pipeline.
You can see what scheduler this pipeline uses with the `pipeline.scheduler` attribute.
```py
pipeline.scheduler
PNDMScheduler {
"_class_name": "PNDMScheduler",
"_diffusers_version": "0.21.4",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": false,
"num_train_timesteps": 1000,
"set_alpha_to_one": false,
"skip_prk_steps": true,
"steps_offset": 1,
"timestep_spacing": "leading",
"trained_betas": null
}
```
## Load a scheduler
Schedulers are defined by a configuration file that can be used by a variety of schedulers. Load a scheduler with the [`SchedulerMixin.from_pretrained`] method, and specify the `subfolder` parameter to load the configuration file into the correct subfolder of the pipeline repository.
For example, to load the [`DDIMScheduler`]:
```py
from diffusers import DDIMScheduler, DiffusionPipeline
ddim = DDIMScheduler.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler")
```
Then you can pass the newly loaded scheduler to the pipeline.
```python
pipeline = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=ddim, torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
```
## Compare schedulers
Schedulers have their own unique strengths and weaknesses, making it difficult to quantitatively compare which scheduler works best for a pipeline. You typically have to make a trade-off between denoising speed and denoising quality. We recommend trying out different schedulers to find one that works best for your use case. Call the `pipeline.scheduler.compatibles` attribute to see what schedulers are compatible with a pipeline.
Let's compare the [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], and the [`DPMSolverMultistepScheduler`] on the following prompt and seed.
```py
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
generator = torch.Generator(device="cuda").manual_seed(8)
```
To change the pipelines scheduler, use the [`~ConfigMixin.from_config`] method to load a different scheduler's `pipeline.scheduler.config` into the pipeline.
<hfoptions id="schedulers">
<hfoption id="LMSDiscreteScheduler">
[`LMSDiscreteScheduler`] typically generates higher quality images than the default scheduler.
```py
from diffusers import LMSDiscreteScheduler
pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
image = pipeline(prompt, generator=generator).images[0]
image
```
</hfoption>
<hfoption id="EulerDiscreteScheduler">
[`EulerDiscreteScheduler`] can generate higher quality images in just 30 steps.
```py
from diffusers import EulerDiscreteScheduler
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
image = pipeline(prompt, generator=generator).images[0]
image
```
</hfoption>
<hfoption id="EulerAncestralDiscreteScheduler">
[`EulerAncestralDiscreteScheduler`] can generate higher quality images in just 30 steps.
```py
from diffusers import EulerAncestralDiscreteScheduler
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
image = pipeline(prompt, generator=generator).images[0]
image
```
</hfoption>
<hfoption id="DPMSolverMultistepScheduler">
[`DPMSolverMultistepScheduler`] provides a balance between speed and quality and can generate higher quality images in just 20 steps.
```py
from diffusers import DPMSolverMultistepScheduler
dpm = DPMSolverMultistepScheduler.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
scheduler=dpm,
torch_dtype=torch.float16,
device_map="cuda"
)
pipeline.scheduler
```
## Timestep schedules
Timestep or noise schedule decides how noise is distributed over the denoising process. The schedule can be linear or more concentrated toward the beginning or end. It is a precomputed sequence of noise levels generated from the scheduler's default configuration, but it can be customized to use other schedules.
> [!TIP]
> The `timesteps` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
The example below uses the [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) schedule which can generate a high-quality image in 10 steps, significantly speeding up generation and reducing computation time.
Import the schedule and pass it to the `timesteps` argument in the pipeline.
```py
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.schedulers import AysSchedules
sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
print(sampling_schedule)
"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
pipeline = DiffusionPipeline.from_pretrained(
"SG161222/RealVisXL_V4.0",
torch_dtype=torch.float16,
device_map="cuda"
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
)
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
image = pipeline(
prompt=prompt,
negative_prompt="",
timesteps=sampling_schedule,
).images[0]
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption>
</div>
</div>
### Rescaling schedules
Denoising should begin with pure noise and the signal-to-noise (SNR) ration should be zero. However, some models don't actually start from pure noise which makes it difficult to generate images at brightness extremes.
> [!TIP]
> Train your own model with `v_prediction` by adding the `--prediction_type="v_prediction"` flag to your training script. You can also [search](https://huggingface.co/search/full-text?q=v_prediction&type=model) for existing models trained with `v_prediction`.
To fix this, a model must be trained with `v_prediction`. If a model is trained with `v_prediction`, then enable the following arguments in the scheduler.
- Set `rescale_betas_zero_snr=True` to rescale the noise schedule to the very last timestep with exactly zero SNR
- Set `timestep_spacing="trailing"` to force sampling from the last timestep with pure noise
```py
from diffusers import DiffusionPipeline, DDIMScheduler
pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", device_map="cuda")
pipeline.scheduler = DDIMScheduler.from_config(
pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
```
Set `guidance_rescale` in the pipeline to avoid overexposed images. A lower value increases brightness, but some details may appear washed out.
```py
prompt = """
cinematic photo of a snowy mountain at night with the northern lights aurora borealis
overhead, 35mm photograph, film, professional, 4k, highly detailed
"""
image = pipeline(prompt, guidance_rescale=0.7).images[0]
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/no-zero-snr.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">default Stable Diffusion v2-1 image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/zero-snr.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">image with zero SNR and trailing timestep spacing enabled</figcaption>
</div>
</div>
## Timestep spacing
Timestep spacing refers to the specific steps *t* to sample from from the schedule. Diffusers provides three spacing types as shown below.
| spacing strategy | spacing calculation | example timesteps |
|---|---|---|
| `leading` | evenly spaced steps | `[900, 800, 700, ..., 100, 0]` |
| `linspace` | include first and last steps and evenly divide remaining intermediate steps | `[1000, 888.89, 777.78, ..., 111.11, 0]` |
| `trailing` | include last step and evenly divide remaining intermediate steps beginning from the end | `[999, 899, 799, 699, 599, 499, 399, 299, 199, 99]` |
Pass the spacing strategy to the `timestep_spacing` argument in the scheduler.
> [!TIP]
> The `trailing` strategy typically produces higher quality images with more details with fewer steps, but the difference in quality is not as obvious for more standard step values.
```py
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
pipeline = DiffusionPipeline.from_pretrained(
"SG161222/RealVisXL_V4.0",
torch_dtype=torch.float16,
device_map="cuda"
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, timestep_spacing="trailing"
)
prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
image = pipeline(
prompt=prompt,
negative_prompt="",
num_inference_steps=5,
).images[0]
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
image = pipeline(prompt, generator=generator).images[0]
image
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/trailing_spacing.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">trailing spacing after 5 steps</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/leading_spacing.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">leading spacing after 5 steps</figcaption>
</div>
</div>
## Sigmas
Sigmas is a measure of how noisy a sample is at a certain step as defined by the schedule. When using custom `sigmas`, the `timesteps` are calculated from these values instead of the default scheduler configuration.
> [!TIP]
> The `sigmas` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
Pass the custom sigmas to the `sigmas` argument in the pipeline. The example below uses the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) from the 10-step AYS schedule.
```py
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
pipeline = DiffusionPipeline.from_pretrained(
"SG161222/RealVisXL_V4.0",
torch_dtype=torch.float16,
device_map="cuda"
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
)
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
image = pipeline(
prompt=prompt,
negative_prompt="",
sigmas=sigmas,
).images[0]
```
### Karras sigmas
[Karras sigmas](https://huggingface.co/papers/2206.00364) resamples the noise schedule for more efficient sampling by clustering sigmas more densely in the middle of the sequence where structure reconstruction is critical, while using fewer sigmas at the beginning and end where noise changes have less impact. This can increase the level of details in a generated image.
Set `use_karras_sigmas=True` in the scheduler to enable it.
```py
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
pipeline = DiffusionPipeline.from_pretrained(
"SG161222/RealVisXL_V4.0",
torch_dtype=torch.float16,
device_map="cuda"
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config,
algorithm_type="sde-dpmsolver++",
use_karras_sigmas=True,
)
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
image = pipeline(
prompt=prompt,
negative_prompt="",
sigmas=sigmas,
).images[0]
```
</hfoption>
</hfoptions>
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_true.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas enabled</figcaption>
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png" />
<figcaption class="mt-2 text-center text-sm text-gray-500">LMSDiscreteScheduler</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_false.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">Karras sigmas disabled</figcaption>
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png" />
<figcaption class="mt-2 text-center text-sm text-gray-500">EulerDiscreteScheduler</figcaption>
</div>
</div>
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png" />
<figcaption class="mt-2 text-center text-sm text-gray-500">EulerAncestralDiscreteScheduler</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png" />
<figcaption class="mt-2 text-center text-sm text-gray-500">DPMSolverMultistepScheduler</figcaption>
</div>
</div>
Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas. It should only be used for models trained with Karras sigmas.
Most images look very similar and are comparable in quality. Again, it often comes down to your specific use case so a good approach is to run multiple different schedulers and compare the results.
## Choosing a scheduler
## Models
It's important to try different schedulers to find the best one for your use case. Here are a few recommendations to help you get started.
Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them.
- DPM++ 2M SDE Karras is generally a good all-purpose option.
- [`TCDScheduler`] works well for distilled models.
- [`FlowMatchEulerDiscreteScheduler`] and [`FlowMatchHeunDiscreteScheduler`] for FlowMatch models.
- [`EulerDiscreteScheduler`] or [`EulerAncestralDiscreteScheduler`] for generating anime style images.
- DPM++ 2M paired with [`LCMScheduler`] on SDXL for generating realistic images.
Models can be loaded from a subfolder with the `subfolder` argument. For example, the model weights for [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) are stored in the [unet](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet) subfolder.
## Resources
```python
from diffusers import UNet2DConditionModel
- Read the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more details about rescaling the noise schedule to enforce zero SNR.
unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", use_safetensors=True)
```
They can also be directly loaded from a [repository](https://huggingface.co/google/ddpm-cifar10-32/tree/main).
```python
from diffusers import UNet2DModel
unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
```
To load and save model variants, specify the `variant` argument in [`ModelMixin.from_pretrained`] and [`ModelMixin.save_pretrained`].
```python
from diffusers import UNet2DConditionModel
unet = UNet2DConditionModel.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
)
unet.save_pretrained("./local-unet", variant="non_ema")
```
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
```py
from diffusers import AutoModel
unet = AutoModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
```
You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights unlike the `torch_dtype` argument that respects the `_keep_in_fp32_modules`. This is important for models whose layers must remain in fp32 for numerical stability and best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
-91
View File
@@ -1,91 +0,0 @@
import logging
import os
from dataclasses import dataclass, field
from typing import List
import torch
from pydantic import BaseModel
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
logger = logging.getLogger(__name__)
class TextToImageInput(BaseModel):
model: str
prompt: str
size: str | None = None
n: int | None = None
@dataclass
class PresetModels:
SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
SD3_5: List[str] = field(
default_factory=lambda: [
"stabilityai/stable-diffusion-3.5-large",
"stabilityai/stable-diffusion-3.5-large-turbo",
"stabilityai/stable-diffusion-3.5-medium",
]
)
class TextToImagePipelineSD3:
def __init__(self, model_path: str | None = None):
self.model_path = model_path or os.getenv("MODEL_PATH")
self.pipeline: StableDiffusion3Pipeline | None = None
self.device: str | None = None
def start(self):
if torch.cuda.is_available():
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
logger.info("Loading CUDA")
self.device = "cuda"
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
model_path,
torch_dtype=torch.float16,
).to(device=self.device)
elif torch.backends.mps.is_available():
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
logger.info("Loading MPS for Mac M Series")
self.device = "mps"
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
).to(device=self.device)
else:
raise Exception("No CUDA or MPS device available")
class ModelPipelineInitializer:
def __init__(self, model: str = "", type_models: str = "t2im"):
self.model = model
self.type_models = type_models
self.pipeline = None
self.device = "cuda" if torch.cuda.is_available() else "mps"
self.model_type = None
def initialize_pipeline(self):
if not self.model:
raise ValueError("Model name not provided")
# Check if model exists in PresetModels
preset_models = PresetModels()
# Determine which model type we're dealing with
if self.model in preset_models.SD3:
self.model_type = "SD3"
elif self.model in preset_models.SD3_5:
self.model_type = "SD3_5"
# Create appropriate pipeline based on model type and type_models
if self.type_models == "t2im":
if self.model_type in ["SD3", "SD3_5"]:
self.pipeline = TextToImagePipelineSD3(self.model)
else:
raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
elif self.type_models == "t2v":
raise ValueError(f"Unsupported type_models: {self.type_models}")
return self.pipeline
-171
View File
@@ -1,171 +0,0 @@
# Asynchronous server and parallel execution of models
> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.
> We recommend running 10 to 50 inferences in parallel for optimal performance, averaging between 25 and 30 seconds to 1 minute and 1 minute and 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two inferences in parallel to avoid decoding or saving errors due to memory shortages.)
## ⚠️ IMPORTANT
* The example demonstrates how to run pipelines like `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on GPU.
## Necessary components
All the components needed to create the inference server are in the current directory:
```
server-async/
├── utils/
├─────── __init__.py
├─────── scheduler.py # BaseAsyncScheduler wrapper and async_retrieve_timesteps for secure inferences
├─────── requestscopedpipeline.py # RequestScoped Pipeline for inference with a single in-memory model
├─────── utils.py # Image/video saving utilities and service configuration
├── Pipelines.py # pipeline loader classes (SD3)
├── serverasync.py # FastAPI app with lifespan management and async inference endpoints
├── test.py # Client test script for inference requests
├── requirements.txt # Dependencies
└── README.md # This documentation
```
## What `diffusers-async` adds / Why we needed it
Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline per-request.
`diffusers-async` / this example addresses that by:
* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.
* **Per-request mutable state**: stateful small objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with fallback to safe `deepcopy` or other heuristics.
* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.
* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully retro-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not using `return_scheduler=True`, the behavior is identical to the original API.
* **Robust attribute handling**: wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. Configurable tensor size threshold prevents cloning of large tensors.
* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).
## How the server works (high-level flow)
1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.
2. On each HTTP inference request:
* The server uses `RequestScopedPipeline.generate(...)` which:
* automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped),
* obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),
* does `local_pipe = copy.copy(base_pipe)` (shallow copy),
* sets `local_pipe.scheduler = local_scheduler` (if possible),
* clones only small mutable attributes (callbacks, rng, small latents) with auto-detection,
* wraps tokenizers with thread-safe locks to prevent race conditions,
* optionally enters a `model_cpu_offload_context()` for memory offload hooks,
* calls the pipeline on the local view (`local_pipe(...)`).
3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`).
4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state.
## How to set up and run the server
### 1) Install dependencies
Recommended: create a virtualenv / conda environment.
```bash
pip install diffusers
pip install -r requirements.txt
```
### 2) Start the server
Using the `serverasync.py` file that already has everything you need:
```bash
python serverasync.py
```
The server will start on `http://localhost:8500` by default with the following features:
- FastAPI application with async lifespan management
- Automatic model loading and pipeline initialization
- Request counting and active inference tracking
- Memory cleanup after each inference
- CORS middleware for cross-origin requests
### 3) Test the server
Use the included test script:
```bash
python test.py
```
Or send a manual request:
`POST /api/diffusers/inference` with JSON body:
```json
{
"prompt": "A futuristic cityscape, vibrant colors",
"num_inference_steps": 30,
"num_images_per_prompt": 1
}
```
Response example:
```json
{
"response": ["http://localhost:8500/images/img123.png"]
}
```
### 4) Server endpoints
- `GET /` - Welcome message
- `POST /api/diffusers/inference` - Main inference endpoint
- `GET /images/{filename}` - Serve generated images
- `GET /api/status` - Server status and memory info
## Advanced Configuration
### RequestScopedPipeline Parameters
```python
RequestScopedPipeline(
pipeline, # Base pipeline to wrap
mutable_attrs=None, # Custom list of attributes to clone
auto_detect_mutables=True, # Enable automatic detection of mutable attributes
tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning
tokenizer_lock=None, # Custom threading lock for tokenizers
wrap_scheduler=True # Auto-wrap scheduler in BaseAsyncScheduler
)
```
### BaseAsyncScheduler Features
* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`
* `clone_for_request()` method for safe per-request scheduler cloning
* Enhanced debugging with `__repr__` and `__str__` methods
* Full compatibility with existing scheduler APIs
### Server Configuration
The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:
```python
@dataclass
class ServerConfigModels:
model: str = 'stabilityai/stable-diffusion-3.5-medium'
type_models: str = 't2im'
host: str = '0.0.0.0'
port: int = 8500
```
## Troubleshooting (quick)
* `Already borrowed` — previously a Rust tokenizer concurrency error.
✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen.
* `can't set attribute 'components'` — pipeline exposes read-only `components`.
✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically.
* Scheduler issues:
* If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log and fallback — but prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.
✅ Note: `async_retrieve_timesteps` is fully retro-compatible — if you don't pass `return_scheduler=True`, the behavior is unchanged.
* Memory issues with large tensors:
✅ The system now has configurable `tensor_numel_threshold` to prevent cloning of large tensors while still cloning small mutable ones.
* Automatic tokenizer detection:
✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers.
-10
View File
@@ -1,10 +0,0 @@
torch
torchvision
transformers
sentencepiece
fastapi
uvicorn
ftfy
accelerate
xformers
protobuf
-230
View File
@@ -1,230 +0,0 @@
import asyncio
import gc
import logging
import os
import random
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from Pipelines import ModelPipelineInitializer
from pydantic import BaseModel
from utils import RequestScopedPipeline, Utils
@dataclass
class ServerConfigModels:
model: str = "stabilityai/stable-diffusion-3.5-medium"
type_models: str = "t2im"
constructor_pipeline: Optional[Type] = None
custom_pipeline: Optional[Type] = None
components: Optional[Dict[str, Any]] = None
torch_dtype: Optional[torch.dtype] = None
host: str = "0.0.0.0"
port: int = 8500
server_config = ServerConfigModels()
@asynccontextmanager
async def lifespan(app: FastAPI):
logging.basicConfig(level=logging.INFO)
app.state.logger = logging.getLogger("diffusers-server")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
app.state.total_requests = 0
app.state.active_inferences = 0
app.state.metrics_lock = asyncio.Lock()
app.state.metrics_task = None
app.state.utils_app = Utils(
host=server_config.host,
port=server_config.port,
)
async def metrics_loop():
try:
while True:
async with app.state.metrics_lock:
total = app.state.total_requests
active = app.state.active_inferences
app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
await asyncio.sleep(5)
except asyncio.CancelledError:
app.state.logger.info("Metrics loop cancelled")
raise
app.state.metrics_task = asyncio.create_task(metrics_loop())
try:
yield
finally:
task = app.state.metrics_task
if task:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
try:
stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
if callable(stop_fn):
await run_in_threadpool(stop_fn)
except Exception as e:
app.state.logger.warning(f"Error during pipeline shutdown: {e}")
app.state.logger.info("Lifespan shutdown complete")
app = FastAPI(lifespan=lifespan)
logger = logging.getLogger("DiffusersServer.Pipelines")
initializer = ModelPipelineInitializer(
model=server_config.model,
type_models=server_config.type_models,
)
model_pipeline = initializer.initialize_pipeline()
model_pipeline.start()
request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
pipeline_lock = threading.Lock()
logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})")
app.state.MODEL_INITIALIZER = initializer
app.state.MODEL_PIPELINE = model_pipeline
app.state.REQUEST_PIPE = request_pipe
app.state.PIPELINE_LOCK = pipeline_lock
class JSONBodyQueryAPI(BaseModel):
model: str | None = None
prompt: str
negative_prompt: str | None = None
num_inference_steps: int = 28
num_images_per_prompt: int = 1
@app.middleware("http")
async def count_requests_middleware(request: Request, call_next):
async with app.state.metrics_lock:
app.state.total_requests += 1
response = await call_next(request)
return response
@app.get("/")
async def root():
return {"message": "Welcome to the Diffusers Server"}
@app.post("/api/diffusers/inference")
async def api(json: JSONBodyQueryAPI):
prompt = json.prompt
negative_prompt = json.negative_prompt or ""
num_steps = json.num_inference_steps
num_images_per_prompt = json.num_images_per_prompt
wrapper = app.state.MODEL_PIPELINE
initializer = app.state.MODEL_INITIALIZER
utils_app = app.state.utils_app
if not wrapper or not wrapper.pipeline:
raise HTTPException(500, "Model not initialized correctly")
if not prompt.strip():
raise HTTPException(400, "No prompt provided")
def make_generator():
g = torch.Generator(device=initializer.device)
return g.manual_seed(random.randint(0, 10_000_000))
req_pipe = app.state.REQUEST_PIPE
def infer():
gen = make_generator()
return req_pipe.generate(
prompt=prompt,
negative_prompt=negative_prompt,
generator=gen,
num_inference_steps=num_steps,
num_images_per_prompt=num_images_per_prompt,
device=initializer.device,
output_type="pil",
)
try:
async with app.state.metrics_lock:
app.state.active_inferences += 1
output = await run_in_threadpool(infer)
async with app.state.metrics_lock:
app.state.active_inferences = max(0, app.state.active_inferences - 1)
urls = [utils_app.save_image(img) for img in output.images]
return {"response": urls}
except Exception as e:
async with app.state.metrics_lock:
app.state.active_inferences = max(0, app.state.active_inferences - 1)
logger.error(f"Error during inference: {e}")
raise HTTPException(500, f"Error in processing: {e}")
finally:
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.ipc_collect()
gc.collect()
@app.get("/images/{filename}")
async def serve_image(filename: str):
utils_app = app.state.utils_app
file_path = os.path.join(utils_app.image_dir, filename)
if not os.path.isfile(file_path):
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(file_path, media_type="image/png")
@app.get("/api/status")
async def get_status():
memory_info = {}
if torch.cuda.is_available():
memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB
memory_info = {
"memory_allocated_gb": round(memory_allocated, 2),
"memory_reserved_gb": round(memory_reserved, 2),
"device": torch.cuda.get_device_name(0),
}
return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=server_config.host, port=server_config.port)
-65
View File
@@ -1,65 +0,0 @@
import os
import time
import urllib.parse
import requests
SERVER_URL = "http://localhost:8500/api/diffusers/inference"
BASE_URL = "http://localhost:8500"
DOWNLOAD_FOLDER = "generated_images"
WAIT_BEFORE_DOWNLOAD = 2 # seconds
os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)
def save_from_url(url: str) -> str:
"""Download the given URL (relative or absolute) and save it locally."""
if url.startswith("/"):
direct = BASE_URL.rstrip("/") + url
else:
direct = url
resp = requests.get(direct, timeout=60)
resp.raise_for_status()
filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png"
path = os.path.join(DOWNLOAD_FOLDER, filename)
with open(path, "wb") as f:
f.write(resp.content)
return path
def main():
payload = {
"prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
"num_inference_steps": 30,
"num_images_per_prompt": 1,
}
print("Sending request...")
try:
r = requests.post(SERVER_URL, json=payload, timeout=480)
r.raise_for_status()
except Exception as e:
print(f"Request failed: {e}")
return
body = r.json().get("response", [])
# Normalize to a list
urls = body if isinstance(body, list) else [body] if body else []
if not urls:
print("No URLs found in the response. Check the server output.")
return
print(f"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...")
time.sleep(WAIT_BEFORE_DOWNLOAD)
for u in urls:
try:
path = save_from_url(u)
print(f"Image saved to: {path}")
except Exception as e:
print(f"Error downloading {u}: {e}")
if __name__ == "__main__":
main()
-2
View File
@@ -1,2 +0,0 @@
from .requestscopedpipeline import RequestScopedPipeline
from .utils import Utils
@@ -1,296 +0,0 @@
import copy
import threading
from typing import Any, Iterable, List, Optional
import torch
from diffusers.utils import logging
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
logger = logging.get_logger(__name__)
def safe_tokenize(tokenizer, *args, lock, **kwargs):
with lock:
return tokenizer(*args, **kwargs)
class RequestScopedPipeline:
DEFAULT_MUTABLE_ATTRS = [
"_all_hooks",
"_offload_device",
"_progress_bar_config",
"_progress_bar",
"_rng_state",
"_last_seed",
"latents",
]
def __init__(
self,
pipeline: Any,
mutable_attrs: Optional[Iterable[str]] = None,
auto_detect_mutables: bool = True,
tensor_numel_threshold: int = 1_000_000,
tokenizer_lock: Optional[threading.Lock] = None,
wrap_scheduler: bool = True,
):
self._base = pipeline
self.unet = getattr(pipeline, "unet", None)
self.vae = getattr(pipeline, "vae", None)
self.text_encoder = getattr(pipeline, "text_encoder", None)
self.components = getattr(pipeline, "components", None)
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
self._auto_detect_mutables = bool(auto_detect_mutables)
self._tensor_numel_threshold = int(tensor_numel_threshold)
self._auto_detected_attrs: List[str] = []
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
base_sched = getattr(self._base, "scheduler", None)
if base_sched is None:
return None
if not isinstance(base_sched, BaseAsyncScheduler):
wrapped_scheduler = BaseAsyncScheduler(base_sched)
else:
wrapped_scheduler = base_sched
try:
return wrapped_scheduler.clone_for_request(
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
)
except Exception as e:
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
try:
return copy.deepcopy(wrapped_scheduler)
except Exception as e:
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
return wrapped_scheduler
def _autodetect_mutables(self, max_attrs: int = 40):
if not self._auto_detect_mutables:
return []
if self._auto_detected_attrs:
return self._auto_detected_attrs
candidates: List[str] = []
seen = set()
for name in dir(self._base):
if name.startswith("__"):
continue
if name in self._mutable_attrs:
continue
if name in ("to", "save_pretrained", "from_pretrained"):
continue
try:
val = getattr(self._base, name)
except Exception:
continue
import types
# skip callables and modules
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
continue
# containers -> candidate
if isinstance(val, (dict, list, set, tuple, bytearray)):
candidates.append(name)
seen.add(name)
else:
# try Tensor detection
try:
if isinstance(val, torch.Tensor):
if val.numel() <= self._tensor_numel_threshold:
candidates.append(name)
seen.add(name)
else:
logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
except Exception:
continue
if len(candidates) >= max_attrs:
break
self._auto_detected_attrs = candidates
logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
return self._auto_detected_attrs
def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
try:
cls = type(base_obj)
descriptor = getattr(cls, attr_name, None)
if isinstance(descriptor, property):
return descriptor.fset is None
if hasattr(descriptor, "__set__") is False and descriptor is not None:
return False
except Exception:
pass
return False
def _clone_mutable_attrs(self, base, local):
attrs_to_clone = list(self._mutable_attrs)
attrs_to_clone.extend(self._autodetect_mutables())
EXCLUDE_ATTRS = {
"components",
}
for attr in attrs_to_clone:
if attr in EXCLUDE_ATTRS:
logger.debug(f"Skipping excluded attr '{attr}'")
continue
if not hasattr(base, attr):
continue
if self._is_readonly_property(base, attr):
logger.debug(f"Skipping read-only property '{attr}'")
continue
try:
val = getattr(base, attr)
except Exception as e:
logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
continue
try:
if isinstance(val, dict):
setattr(local, attr, dict(val))
elif isinstance(val, (list, tuple, set)):
setattr(local, attr, list(val))
elif isinstance(val, bytearray):
setattr(local, attr, bytearray(val))
else:
# small tensors or atomic values
if isinstance(val, torch.Tensor):
if val.numel() <= self._tensor_numel_threshold:
setattr(local, attr, val.clone())
else:
# don't clone big tensors, keep reference
setattr(local, attr, val)
else:
try:
setattr(local, attr, copy.copy(val))
except Exception:
setattr(local, attr, val)
except (AttributeError, TypeError) as e:
logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
continue
except Exception as e:
logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
continue
def _is_tokenizer_component(self, component) -> bool:
if component is None:
return False
tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
class_name = component.__class__.__name__.lower()
has_tokenizer_in_name = "tokenizer" in class_name
tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
try:
local_pipe = copy.copy(self._base)
except Exception as e:
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
local_pipe = copy.deepcopy(self._base)
if local_scheduler is not None:
try:
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
local_scheduler.scheduler,
num_inference_steps=num_inference_steps,
device=device,
return_scheduler=True,
**{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
)
final_scheduler = BaseAsyncScheduler(configured_scheduler)
setattr(local_pipe, "scheduler", final_scheduler)
except Exception:
logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
self._clone_mutable_attrs(self._base, local_pipe)
# 4) wrap tokenizers on the local pipe with the lock wrapper
tokenizer_wrappers = {} # name -> original_tokenizer
try:
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
for name in dir(local_pipe):
if "tokenizer" in name and not name.startswith("_"):
tok = getattr(local_pipe, name, None)
if tok is not None and self._is_tokenizer_component(tok):
tokenizer_wrappers[name] = tok
setattr(
local_pipe,
name,
lambda *args, tok=tok, **kwargs: safe_tokenize(
tok, *args, lock=self._tokenizer_lock, **kwargs
),
)
# b) wrap tokenizers in components dict
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
for key, val in local_pipe.components.items():
if val is None:
continue
if self._is_tokenizer_component(val):
tokenizer_wrappers[f"components[{key}]"] = val
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
)
except Exception as e:
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
result = None
cm = getattr(local_pipe, "model_cpu_offload_context", None)
try:
if callable(cm):
try:
with cm():
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
except TypeError:
# cm might be a context manager instance rather than callable
try:
with cm:
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
except Exception as e:
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
else:
# no offload context available — call directly
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
return result
finally:
try:
for name, tok in tokenizer_wrappers.items():
if name.startswith("components["):
key = name[len("components[") : -1]
local_pipe.components[key] = tok
else:
setattr(local_pipe, name, tok)
except Exception as e:
logger.debug(f"Error restoring wrapped tokenizers: {e}")
-141
View File
@@ -1,141 +0,0 @@
import copy
import inspect
from typing import Any, List, Optional, Union
import torch
class BaseAsyncScheduler:
def __init__(self, scheduler: Any):
self.scheduler = scheduler
def __getattr__(self, name: str):
if hasattr(self.scheduler, name):
return getattr(self.scheduler, name)
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __setattr__(self, name: str, value):
if name == "scheduler":
super().__setattr__(name, value)
else:
if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
setattr(self.scheduler, name, value)
else:
super().__setattr__(name, value)
def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
local = copy.deepcopy(self.scheduler)
local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
cloned = self.__class__(local)
return cloned
def __repr__(self):
return f"BaseAsyncScheduler({repr(self.scheduler)})"
def __str__(self):
return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
def async_retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Backwards compatible: by default the function behaves exactly as before and returns
(timesteps_tensor, num_inference_steps)
If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:
(timesteps_tensor, num_inference_steps, scheduler_in_use)
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Optional kwargs:
return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)
where `scheduler_in_use` is a scheduler instance that already has timesteps set.
This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.
Returns:
`(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
`(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
"""
# pop our optional control kwarg (keeps compatibility)
return_scheduler = bool(kwargs.pop("return_scheduler", False))
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
# choose scheduler to call set_timesteps on
scheduler_in_use = scheduler
if return_scheduler:
# Do not mutate the provided scheduler: prefer to clone if possible
if hasattr(scheduler, "clone_for_request"):
try:
# clone_for_request may accept num_inference_steps or other kwargs; be permissive
scheduler_in_use = scheduler.clone_for_request(
num_inference_steps=num_inference_steps or 0, device=device
)
except Exception:
scheduler_in_use = copy.deepcopy(scheduler)
else:
# fallback deepcopy (scheduler tends to be smallish - acceptable)
scheduler_in_use = copy.deepcopy(scheduler)
# helper to test if set_timesteps supports a particular kwarg
def _accepts(param_name: str) -> bool:
try:
return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())
except (ValueError, TypeError):
# if signature introspection fails, be permissive and attempt the call later
return False
# now call set_timesteps on the chosen scheduler_in_use (may be original or clone)
if timesteps is not None:
accepts_timesteps = _accepts("timesteps")
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps_out = scheduler_in_use.timesteps
num_inference_steps = len(timesteps_out)
elif sigmas is not None:
accept_sigmas = _accepts("sigmas")
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps_out = scheduler_in_use.timesteps
num_inference_steps = len(timesteps_out)
else:
# default path
scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps_out = scheduler_in_use.timesteps
if return_scheduler:
return timesteps_out, num_inference_steps, scheduler_in_use
return timesteps_out, num_inference_steps
-48
View File
@@ -1,48 +0,0 @@
import gc
import logging
import os
import tempfile
import uuid
import torch
logger = logging.getLogger(__name__)
class Utils:
def __init__(self, host: str = "0.0.0.0", port: int = 8500):
self.service_url = f"http://{host}:{port}"
self.image_dir = os.path.join(tempfile.gettempdir(), "images")
if not os.path.exists(self.image_dir):
os.makedirs(self.image_dir)
self.video_dir = os.path.join(tempfile.gettempdir(), "videos")
if not os.path.exists(self.video_dir):
os.makedirs(self.video_dir)
def save_image(self, image):
if hasattr(image, "to"):
try:
image = image.to("cpu")
except Exception:
pass
if isinstance(image, torch.Tensor):
from torchvision import transforms
to_pil = transforms.ToPILImage()
image = to_pil(image.squeeze(0).clamp(0, 1))
filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png"
image_path = os.path.join(self.image_dir, filename)
logger.info(f"Saving image to {image_path}")
image.save(image_path, format="PNG", optimize=True)
del image
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return os.path.join(self.service_url, "images", filename)
+2 -2
View File
@@ -9,8 +9,8 @@ This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server
Start by navigating to the `examples/server` folder and installing all of the dependencies.
```py
pip install diffusers
pip install -r requirements.txt
pip install .
pip install -f requirements.txt
```
Launch the server with the following command.
+1 -2
View File
@@ -6,5 +6,4 @@ py-consul
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
fastapi
uvicorn
accelerate
uvicorn
+1 -1
View File
@@ -39,7 +39,7 @@ fsspec==2024.10.0
# torch
h11==0.14.0
# via uvicorn
huggingface-hub==0.35.0
huggingface-hub==0.26.1
# via
# tokenizers
# transformers
+1 -34
View File
@@ -278,29 +278,6 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-VACE-Fun-14B":
config = {
"model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
"vace_in_channels": 96,
},
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-I2V-14B-720p":
config = {
"model_id": "Wan-AI/Wan2.2-I2V-A14B",
@@ -998,17 +975,7 @@ if __name__ == "__main__":
image_encoder=image_encoder,
image_processor=image_processor,
)
elif "Wan2.2-VACE" in args.model_type:
pipe = WanVACEPipeline(
transformer=transformer,
transformer_2=transformer_2,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
boundary_ratio=0.875,
)
elif "Wan-VACE" in args.model_type:
elif "VACE" in args.model_type:
pipe = WanVACEPipeline(
transformer=transformer,
text_encoder=text_encoder,
-4
View File
@@ -495,7 +495,6 @@ else:
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
"LucyEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
@@ -515,7 +514,6 @@ else:
"QwenImageControlNetPipeline",
"QwenImageEditInpaintPipeline",
"QwenImageEditPipeline",
"QwenImageEditPlusPipeline",
"QwenImageImg2ImgPipeline",
"QwenImageInpaintPipeline",
"QwenImagePipeline",
@@ -1151,7 +1149,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
LucyEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
LuminaPipeline,
@@ -1171,7 +1168,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
QwenImageEditPipeline,
QwenImageEditPlusPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImagePipeline,
-35
View File
@@ -1064,41 +1064,6 @@ class LoraBaseMixin:
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
@classmethod
def _save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
lora_metadata: Dict[str, Optional[dict]],
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
"""
Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
pipeline types.
"""
state_dict = {}
final_lora_adapter_metadata = {}
for prefix, layers in lora_layers.items():
state_dict.update(cls.pack_weights(layers, prefix))
for prefix, metadata in lora_metadata.items():
if metadata:
final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
)
@classmethod
def _optionally_disable_offloading(cls, _pipeline):
return _func_optionally_disable_offloading(_pipeline=_pipeline)
+47 -39
View File
@@ -558,62 +558,70 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
ait_sd[target_key] = value
if any("guidance_in" in k for k in sds_sd):
_convert_to_ai_toolkit(
assign_remaining_weights(
[
(
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
ait_sd,
"lora_unet_guidance_in_in_layer",
"time_text_embed.guidance_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_guidance_in_out_layer",
"time_text_embed.guidance_embedder.linear_2",
)
if any("img_in" in k for k in sds_sd):
_convert_to_ai_toolkit(
assign_remaining_weights(
[
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
],
sds_sd,
ait_sd,
"lora_unet_img_in",
"x_embedder",
)
if any("txt_in" in k for k in sds_sd):
_convert_to_ai_toolkit(
assign_remaining_weights(
[
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
],
sds_sd,
ait_sd,
"lora_unet_txt_in",
"context_embedder",
)
if any("time_in" in k for k in sds_sd):
_convert_to_ai_toolkit(
assign_remaining_weights(
[
(
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
ait_sd,
"lora_unet_time_in_in_layer",
"time_text_embed.timestep_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_time_in_out_layer",
"time_text_embed.timestep_embedder.linear_2",
)
if any("vector_in" in k for k in sds_sd):
_convert_to_ai_toolkit(
assign_remaining_weights(
[
(
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
ait_sd,
"lora_unet_vector_in_in_layer",
"time_text_embed.text_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_vector_in_out_layer",
"time_text_embed.text_embedder.linear_2",
)
if any("final_layer" in k for k in sds_sd):
+266 -180
View File
@@ -510,28 +510,35 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
if unet_lora_layers:
lora_layers[cls.unet_name] = unet_lora_layers
lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
if text_encoder_lora_layers:
lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
if not lora_layers:
raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
if unet_lora_adapter_metadata:
lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
cls._save_lora_weights(
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -997,34 +1004,44 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
text_encoder_2_lora_adapter_metadata:
LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if unet_lora_layers:
lora_layers[cls.unet_name] = unet_lora_layers
lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
if text_encoder_lora_layers:
lora_layers["text_encoder"] = text_encoder_lora_layers
lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
if text_encoder_2_lora_layers:
lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
if not lora_layers:
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
)
cls._save_lora_weights(
if unet_lora_layers:
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
if unet_lora_adapter_metadata is not None:
lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
if text_encoder_2_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
)
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -1450,34 +1467,46 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
text_encoder_2_lora_adapter_metadata:
LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers:
lora_layers["text_encoder"] = text_encoder_lora_layers
lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
if text_encoder_2_lora_layers:
lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
if not lora_layers:
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
)
cls._save_lora_weights(
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
if text_encoder_2_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
)
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
@@ -1801,24 +1830,28 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -2402,28 +2435,37 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if text_encoder_lora_layers:
lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
if transformer_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
cls._save_lora_weights(
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -3212,24 +3254,28 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -3548,24 +3594,28 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -3888,24 +3938,28 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4226,24 +4280,28 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4566,24 +4624,28 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4907,24 +4969,28 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -5318,24 +5384,28 @@ class WanLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -5732,24 +5802,28 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -6070,24 +6144,28 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -6410,24 +6488,28 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -6753,24 +6835,28 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
state_dict = {}
lora_adapter_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
cls._save_lora_weights(
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+25 -2
View File
@@ -17,10 +17,11 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate
from ..utils.import_utils import is_torch_npu_available, is_torch_version
from ..utils import deprecate, get_logger, is_torch_npu_available, is_torch_version
logger = get_logger(__name__)
if is_torch_npu_available():
import torch_npu
@@ -31,6 +32,7 @@ ACT2CLS = {
"gelu": nn.GELU,
"relu": nn.ReLU,
}
KERNELS_REPO_ID = "kernels-community/activation"
def get_activation(act_fn: str) -> nn.Module:
@@ -90,6 +92,27 @@ class GELU(nn.Module):
return hidden_states
# TODO: validation checks / consider making Python classes of activations like `transformers`
# All of these are temporary for now.
class CUDAOptimizedGELU(GELU):
def __init__(self, *args, **kwargs):
from kernels import get_kernel
activation = get_kernel("kernels-community/activation", revision="add_more_act")
approximate = kwargs.get("approximate", "none")
super().__init__(*args, **kwargs)
if approximate == "none":
self.act_fn = activation.layers.Gelu()
elif approximate == "tanh":
self.act_fn = activation.layers.GeluTanh()
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.act_fn(hidden_states)
return hidden_states
class GEGLU(nn.Module):
r"""
A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.
+2 -2
View File
@@ -241,7 +241,7 @@ class AttentionModuleMixin:
op_fw, op_bw = attention_op
dtype, *_ = op_fw.SUPPORTED_DTYPES
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
_ = xops.ops.memory_efficient_attention(q, q, q)
_ = xops.memory_efficient_attention(q, q, q)
except Exception as e:
raise e
@@ -674,7 +674,7 @@ class JointTransformerBlock(nn.Module):
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
):
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+22 -45
View File
@@ -19,7 +19,6 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
logger = logging.get_logger(__name__)
@@ -115,8 +114,6 @@ class AutoModel(ConfigMixin):
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
trust_remote_cocde (`bool`, *optional*, defaults to `False`):
Whether to trust remote code
<Tip>
@@ -143,22 +140,22 @@ class AutoModel(ConfigMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
hub_kwargs_names = [
"cache_dir",
"force_download",
"local_files_only",
"proxies",
"resume_download",
"revision",
"token",
]
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
load_config_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"token": token,
"local_files_only": local_files_only,
"revision": revision,
}
library = None
orig_class_name = None
@@ -192,35 +189,15 @@ class AutoModel(ConfigMixin):
else:
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(trust_remote_code, pretrained_model_or_path, has_remote_code)
if not has_remote_code and trust_remote_code:
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
if has_remote_code and trust_remote_code:
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
model_cls = get_class_from_dynamic_module(
pretrained_model_or_path,
subfolder=subfolder,
module_file=module_file,
class_name=class_name,
**hub_kwargs,
**kwargs,
)
else:
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=None,
is_pipeline_module=False,
)
model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=None,
is_pipeline_module=False,
)
if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
@@ -617,7 +617,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
returned.
"""
if self.use_slicing and z.size(0) > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
@@ -1052,7 +1052,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
is_residual=is_residual,
)
self.spatial_compression_ratio = scale_factor_spatial
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
@@ -1145,13 +1145,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def _encode(self, x: torch.Tensor):
_, _, num_frame, height, width = x.shape
self.clear_cache()
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
self.clear_cache()
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
+27 -5
View File
@@ -20,11 +20,20 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import is_torch_npu_available, is_torch_version
from ..utils import is_kernels_available, is_torch_npu_available, is_torch_version
from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
from ..utils.kernels_utils import use_kernel_forward_from_hub
from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
from kernels import get_kernel
activation = get_kernel("kernels-community/activation", revision="add_more_act")
silu_kernel = activation.layers.Silu
class AdaLayerNorm(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
@@ -57,7 +66,10 @@ class AdaLayerNorm(nn.Module):
else:
self.emb = None
self.silu = nn.SiLU()
if DIFFUSERS_ENABLE_HUB_KERNELS:
self.silu = silu_kernel()
else:
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
@@ -144,7 +156,10 @@ class AdaLayerNormZero(nn.Module):
else:
self.emb = None
self.silu = nn.SiLU()
if DIFFUSERS_ENABLE_HUB_KERNELS:
self.silu = silu_kernel()
else:
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -183,7 +198,10 @@ class AdaLayerNormZeroSingle(nn.Module):
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
super().__init__()
self.silu = nn.SiLU()
if DIFFUSERS_ENABLE_HUB_KERNELS:
self.silu = silu_kernel()
else:
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -335,7 +353,10 @@ class AdaLayerNormContinuous(nn.Module):
norm_type="layer_norm",
):
super().__init__()
self.silu = nn.SiLU()
if DIFFUSERS_ENABLE_HUB_KERNELS:
self.silu = silu_kernel()
else:
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
@@ -508,6 +529,7 @@ else:
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
@use_kernel_forward_from_hub("RMSNorm")
class RMSNorm(nn.Module):
r"""
RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
@@ -92,7 +92,7 @@ class AuraFlowPatchEmbed(nn.Module):
return selected_indices
def forward(self, latent) -> torch.Tensor:
def forward(self, latent):
batch_size, num_channels, height, width = latent.size()
latent = latent.view(
batch_size,
@@ -173,7 +173,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
):
residual = hidden_states
attention_kwargs = attention_kwargs or {}
@@ -242,7 +242,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
):
residual = hidden_states
residual_context = encoder_hidden_states
attention_kwargs = attention_kwargs or {}
@@ -472,7 +472,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
timestep: torch.LongTensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -122,7 +122,7 @@ class CogVideoXBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
attention_kwargs = attention_kwargs or {}
@@ -441,7 +441,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
):
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -315,7 +315,7 @@ class ConsisIDBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
@@ -691,7 +691,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
id_cond: Optional[torch.Tensor] = None,
id_vit_hidden: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
):
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
@@ -124,7 +124,7 @@ class LuminaNextDiTBlock(nn.Module):
encoder_mask: torch.Tensor,
temb: torch.Tensor,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
):
"""
Perform a forward pass through the LuminaNextDiTBlock.
@@ -297,7 +297,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
image_rotary_emb: torch.Tensor,
cross_attention_kwargs: Dict[str, Any] = None,
return_dict=True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
) -> torch.Tensor:
"""
Forward pass of LuminaNextDiT.
@@ -472,7 +472,7 @@ class BriaSingleTransformerBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -588,7 +588,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
return_dict: bool = True,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`BriaTransformer2DModel`] forward method.
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Dict, Tuple, Union
from typing import Dict, Union
import torch
import torch.nn as nn
@@ -79,7 +79,7 @@ class CogView3PlusTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
emb: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
@@ -293,7 +293,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
target_size: torch.Tensor,
crop_coords: torch.Tensor,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`CogView3PlusTransformer2DModel`] forward method.
@@ -494,7 +494,7 @@ class CogView4TransformerBlock(nn.Module):
] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
# 1. Timestep conditioning
(
norm_hidden_states,
@@ -717,7 +717,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
) -> Union[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -22,7 +22,8 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -40,6 +41,12 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
from kernels import get_kernel
activation = get_kernel("kernels-community/activation", revision="add_more_act")
gelu_tanh_kernel = activation.layers.GeluTanh
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
@@ -300,8 +307,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
if DIFFUSERS_ENABLE_HUB_KERNELS:
from ..normalization import RMSNorm
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
else:
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
@@ -312,8 +325,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
if DIFFUSERS_ENABLE_HUB_KERNELS:
from ..normalization import RMSNorm
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
else:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
@@ -351,6 +370,11 @@ class FluxSingleTransformerBlock(nn.Module):
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
# if not DIFFUSERS_ENABLE_HUB_KERNELS:
# self.act_mlp = nn.GELU(approximate="tanh")
# else:
# self.act_mlp = gelu_tanh_kernel()
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
self.attn = FluxAttention(
@@ -55,7 +55,7 @@ class HiDreamImageTimestepEmbed(nn.Module):
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
@@ -87,7 +87,7 @@ class HiDreamImagePatchEmbed(nn.Module):
self.out_channels = out_channels
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
def forward(self, latent) -> torch.Tensor:
def forward(self, latent):
latent = self.proj(latent)
return latent
@@ -534,7 +534,7 @@ class HiDreamImageTransformerBlock(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
wtype = hidden_states.dtype
(
shift_msa_i,
@@ -592,7 +592,7 @@ class HiDreamBlock(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> torch.Tensor:
return self.block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
@@ -786,7 +786,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
**kwargs,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
):
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if encoder_hidden_states is not None:
@@ -529,7 +529,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -684,7 +684,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -1038,7 +1038,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
@@ -216,7 +216,7 @@ class HunyuanVideoFramepackTransformer3DModel(
indices_latents_history_4x: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
):
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
+2 -11
View File
@@ -82,7 +82,6 @@ class UNet1DModel(ModelMixin, ConfigMixin):
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
time_embedding_dim: Optional[int] = None,
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
@@ -101,23 +100,15 @@ class UNet1DModel(ModelMixin, ConfigMixin):
# time
if time_embedding_type == "fourier":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = time_embed_dim
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
@@ -25,7 +25,6 @@ from ..utils import (
is_accelerate_available,
logging,
)
from ..utils.torch_utils import get_device
if is_accelerate_available():
@@ -162,9 +161,7 @@ class AutoOffloadStrategy:
current_module_size = model.get_memory_footprint()
device_type = execution_device.type
device_module = getattr(torch, device_type, torch.cuda)
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
mem_on_device = mem_on_device - self.memory_reserve_margin
if current_module_size < mem_on_device:
return []
@@ -304,7 +301,7 @@ class ComponentsManager:
cm.add("vae", vae_model, collection="sdxl")
# Enable auto offloading
cm.enable_auto_cpu_offload()
cm.enable_auto_cpu_offload(device="cuda")
# Retrieve components
unet = cm.get_one(name="unet", collection="sdxl")
@@ -493,8 +490,6 @@ class ComponentsManager:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if torch.xpu.is_available():
torch.xpu.empty_cache()
# YiYi TODO: rename to search_components for now, may remove this method
def search_components(
@@ -683,7 +678,7 @@ class ComponentsManager:
return get_return_dict(matches, return_dict_with_names)
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, memory_reserve_margin="3GB"):
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
"""
Enable automatic CPU offloading for all components.
@@ -709,8 +704,6 @@ class ComponentsManager:
self.disable_auto_cpu_offload()
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
if device is None:
device = get_device()
device = torch.device(device)
if device.index is None:
device = torch.device(f"{device.type}:{0}")
@@ -323,7 +323,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not has_remote_code and trust_remote_code:
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
-4
View File
@@ -285,7 +285,6 @@ else:
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
@@ -394,7 +393,6 @@ else:
"QwenImageImg2ImgPipeline",
"QwenImageInpaintPipeline",
"QwenImageEditPipeline",
"QwenImageEditPlusPipeline",
"QwenImageEditInpaintPipeline",
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
@@ -684,7 +682,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (
@@ -722,7 +719,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
QwenImageEditPipeline,
QwenImageEditPlusPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImagePipeline,
@@ -688,11 +688,11 @@ class ChromaPipeline(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
a model to generate images more aligned with `prompt` at the expense of lower image quality.
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -749,12 +749,12 @@ class ChromaImg2ImgPipeline(
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
guidance_scale (`float`, *optional*, defaults to 5.0):
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
a model to generate images more aligned with `prompt` at the expense of lower image quality.
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
strength (`float, *optional*, defaults to 0.9):
Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
be used as a starting point, adding more noise to it the larger the strength. The number of denoising
-47
View File
@@ -1,47 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_lucy_edit"] = ["LucyEditPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_lucy_edit import LucyEditPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -1,735 +0,0 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
# Copyright 2025 The Decart AI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications by Decart AI Team:
# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.
import html
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import regex as re
import torch
from PIL import Image
from transformers import AutoTokenizer, UMT5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import WanLoraLoaderMixin
from ...models import AutoencoderKLWan, WanTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import LucyPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> from typing import List
>>> import torch
>>> from PIL import Image
>>> from diffusers import AutoencoderKLWan, LucyEditPipeline
>>> from diffusers.utils import export_to_video, load_video
>>> # Arguments
>>> url = "https://d2drjpuinn46lb.cloudfront.net/painter_original_edit.mp4"
>>> prompt = "Change the apron and blouse to a classic clown costume: satin polka-dot jumpsuit in bright primary colors, ruffled white collar, oversized pom-pom buttons, white gloves, oversized red shoes, red foam nose; soft window light from left, eye-level medium shot, natural folds and fabric highlights."
>>> negative_prompt = ""
>>> num_frames = 81
>>> height = 480
>>> width = 832
>>> # Load video
>>> def convert_video(video: List[Image.Image]) -> List[Image.Image]:
... video = load_video(url)[:num_frames]
... video = [video[i].resize((width, height)) for i in range(num_frames)]
... return video
>>> video = load_video(url, convert_method=convert_video)
>>> # Load model
>>> model_id = "decart-ai/Lucy-Edit-Dev"
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
>>> pipe = LucyEditPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> # Generate video
>>> output = pipe(
... prompt=prompt,
... video=video,
... negative_prompt=negative_prompt,
... height=480,
... width=832,
... num_frames=81,
... guidance_scale=5.0,
... ).frames[0]
>>> # Export video
>>> export_to_video(output, "output.mp4", fps=24)
```
"""
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
text = whitespace_clean(basic_clean(text))
return text
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class LucyEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
r"""
Pipeline for video-to-video generation using Lucy Edit.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
tokenizer ([`T5Tokenizer`]):
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer_2 ([`WanTransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
stages. If not provided, only `transformer` is used.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer", "transformer_2"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer: Optional[WanTransformer3DModel] = None,
transformer_2: Optional[WanTransformer3DModel] = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False, # Wan2.2 ti2v
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
transformer_2=transformer_2,
)
self.register_to_config(boundary_ratio=boundary_ratio)
self.register_to_config(expand_timesteps=expand_timesteps)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def check_inputs(
self,
video,
prompt,
negative_prompt,
height,
width,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
guidance_scale_2=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
if video is None:
raise ValueError("`video` is required, received None.")
def prepare_latents(
self,
video: Optional[torch.Tensor] = None,
batch_size: int = 1,
num_channels_latents: int = 16,
height: int = 480,
width: int = 832,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
num_latent_frames = (
(video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
)
shape = (
batch_size,
num_channels_latents,
num_latent_frames,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
# Prepare noise latents
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# Prepare condition latents
condition_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video
]
condition_latents = torch.cat(condition_latents, dim=0).to(dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
device, dtype
)
condition_latents = (condition_latents - latents_mean) * latents_std
# Check shapes
assert latents.shape == condition_latents.shape, (
f"Latents shape {latents.shape} does not match expected shape {condition_latents.shape}. Please check the input."
)
return latents, condition_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@property
def attention_kwargs(self):
return self._attention_kwargs
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
video: List[Image.Image],
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
The call function to the pipeline for generation.
Args:
video (`List[Image.Image]`):
The video to use as the condition for the video generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
instead. Ignored when not using guidance (`guidance_scale` < `1`).
height (`int`, defaults to `480`):
The height in pixels of the generated image.
width (`int`, defaults to `832`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `81`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
guidance_scale_2 (`float`, *optional*, defaults to `None`):
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
output_type (`str`, *optional*, defaults to `"np"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`LucyPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `512`):
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
truncated. If the prompt is shorter, it will be padded to this length.
Examples:
Returns:
[`~LucyPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`LucyPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
video,
prompt,
negative_prompt,
height,
width,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
)
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale
self._guidance_scale = guidance_scale
self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = (
self.transformer.config.out_channels
if self.transformer is not None
else self.transformer_2.config.out_channels
)
video = self.video_processor.preprocess_video(video, height=height, width=width).to(
device, dtype=torch.float32
)
latents, condition_latents = self.prepare_latents(
video,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
torch.float32,
device,
generator,
latents,
)
mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
boundary_timestep = None
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
if boundary_timestep is None or t >= boundary_timestep:
# wan2.1 or high-noise stage in wan2.2
current_model = self.transformer
current_guidance_scale = guidance_scale
else:
# low-noise stage in wan2.2
current_model = self.transformer_2
current_guidance_scale = guidance_scale_2
# latent_model_input = latents.to(transformer_dtype)
latent_model_input = torch.cat([latents, condition_latents], dim=1).to(transformer_dtype)
# latent_model_input = torch.cat([latents, latents], dim=1).to(transformer_dtype)
if self.config.expand_timesteps:
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
# batch_size, seq_len
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
else:
timestep = t.expand(latents.shape[0])
with current_model.cache_context("cond"):
noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with current_model.cache_context("uncond"):
noise_uncond = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LucyPipelineOutput(frames=video)
@@ -1,20 +0,0 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class LucyPipelineOutput(BaseOutput):
r"""
Output class for Lucy pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
@@ -48,12 +48,10 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
if is_transformers_version("<=", "4.56.2"):
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
if is_accelerate_available():
import accelerate
from accelerate import dispatch_model
@@ -114,9 +112,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -195,9 +191,7 @@ def filter_model_files(filenames):
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
@@ -218,9 +212,7 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -28,7 +28,6 @@ else:
_import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"]
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
_import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
_import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
@@ -44,7 +43,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_qwenimage_controlnet_inpaint import QwenImageControlNetInpaintPipeline
from .pipeline_qwenimage_edit import QwenImageEditPipeline
from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline
from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
else:
@@ -208,6 +208,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.vl_processor = processor
self.tokenizer_max_length = 1024
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
@@ -1,883 +0,0 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import QwenImageLoraLoaderMixin
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import QwenImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from PIL import Image
>>> from diffusers import QwenImageEditPlusPipeline
>>> from diffusers.utils import load_image
>>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> image = load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
... ).convert("RGB")
>>> prompt = (
... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
... )
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(image, prompt, num_inference_steps=50).images[0]
>>> image.save("qwenimage_edit_plus.png")
```
"""
CONDITION_IMAGE_SIZE = 384 * 384
VAE_IMAGE_SIZE = 1024 * 1024
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def calculate_dimensions(target_area, ratio):
width = math.sqrt(target_area * ratio)
height = width / ratio
width = round(width / 32) * 32
height = round(height / 32) * 32
return width, height
class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
r"""
The Qwen-Image-Edit pipeline for image editing.
Args:
transformer ([`QwenImageTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
tokenizer (`QwenTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLQwenImage,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2Tokenizer,
processor: Qwen2VLProcessor,
transformer: QwenImageTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
processor=processor,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = 1024
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.prompt_template_encode_start_idx = 64
self.default_sample_size = 128
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)
selected = hidden_states[bool_mask]
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
return split_result
def _get_qwen_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
image: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
if isinstance(image, list):
base_img_prompt = ""
for i, img in enumerate(image):
base_img_prompt += img_prompt_template.format(i + 1)
elif image is not None:
base_img_prompt = img_prompt_template.format(1)
else:
base_img_prompt = ""
template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(base_img_prompt + e) for e in prompt]
model_inputs = self.processor(
text=txt,
images=image,
padding=True,
return_tensors="pt",
).to(device)
outputs = self.text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
output_hidden_states=True,
)
hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds, encoder_attention_mask
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
image: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 1024,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
image (`torch.Tensor`, *optional*):
image to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
return prompt_embeds, prompt_embeds_mask
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_embeds_mask=None,
negative_prompt_embeds_mask=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 1024:
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
@staticmethod
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
return latents
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.latent_channels, 1, 1, 1)
.to(image_latents.device, image_latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, self.latent_channels, 1, 1, 1)
.to(image_latents.device, image_latents.dtype)
)
image_latents = (image_latents - latents_mean) / latents_std
return image_latents
def prepare_latents(
self,
images,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, 1, num_channels_latents, height, width)
image_latents = None
if images is not None:
if not isinstance(images, list):
images = [images]
all_image_latents = []
for image in images:
image = image.to(device=device, dtype=dtype)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
image_latent_height, image_latent_width = image_latents.shape[3:]
image_latents = self._pack_latents(
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
)
all_image_latents.append(image_latents)
image_latents = torch.cat(all_image_latents, dim=1)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
else:
latents = latents.to(device=device, dtype=dtype)
return latents, image_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: Optional[PipelineImageInput] = None,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
true_cfg_scale: float = 4.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: Optional[float] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
true_cfg_scale (`float`, *optional*, defaults to 1.0):
true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
lower image quality.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
where the guidance scale is applied during inference through noise prediction rescaling, guidance
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
enable classifier-free guidance computations).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
image_size = image[-1].size if isinstance(image, list) else image.size
calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
height = height or calculated_height
width = width or calculated_width
multiple_of = self.vae_scale_factor * 2
width = width // multiple_of * multiple_of
height = height // multiple_of * multiple_of
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Preprocess image
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
if not isinstance(image, list):
image = [image]
condition_image_sizes = []
condition_images = []
vae_image_sizes = []
vae_images = []
for img in image:
image_width, image_height = img.size
condition_width, condition_height = calculate_dimensions(
CONDITION_IMAGE_SIZE, image_width / image_height
)
vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
condition_image_sizes.append((condition_width, condition_height))
vae_image_sizes.append((vae_width, vae_height))
condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
)
elif true_cfg_scale <= 1 and has_neg_prompt:
logger.warning(
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
image=condition_images,
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
if do_true_cfg:
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
image=condition_images,
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, image_latents = self.prepare_latents(
vae_images,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
img_shapes = [
[
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
*[
(1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
for vae_width, vae_height in vae_image_sizes
],
]
] * batch_size
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds and guidance_scale is None:
raise ValueError("guidance_scale is required for guidance-distilled model.")
elif self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
logger.warning(
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
)
guidance = None
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
guidance = None
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
if do_true_cfg:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return QwenImagePipelineOutput(images=image)
@@ -152,26 +152,16 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanVACETransformer3DModel`]):
transformer ([`WanTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
`transformer` is used.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]
def __init__(
self,
@@ -180,8 +170,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
transformer: WanVACETransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer_2: WanVACETransformer3DModel = None,
boundary_ratio: Optional[float] = None,
):
super().__init__()
@@ -190,10 +178,9 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
transformer_2=transformer_2,
scheduler=scheduler,
)
self.register_to_config(boundary_ratio=boundary_ratio)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -334,7 +321,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
video=None,
mask=None,
reference_images=None,
guidance_scale_2=None,
):
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
if height % base != 0 or width % base != 0:
@@ -346,8 +332,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -683,7 +667,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -745,10 +728,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
guidance_scale_2 (`float`, *optional*, defaults to `None`):
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -814,7 +793,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
video,
mask,
reference_images,
guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -824,11 +802,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale
self._guidance_scale = guidance_scale
self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -922,53 +896,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
boundary_timestep = None
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
if boundary_timestep is None or t >= boundary_timestep:
# wan2.1 or high-noise stage in wan2.2
current_model = self.transformer
current_guidance_scale = guidance_scale
else:
# low-noise stage in wan2.2
current_model = self.transformer_2
current_guidance_scale = guidance_scale_2
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
with current_model.cache_context("cond"):
noise_pred = current_model(
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
control_hidden_states=conditioning_latents,
control_hidden_states_scale=conditioning_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
control_hidden_states=conditioning_latents,
control_hidden_states_scale=conditioning_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with current_model.cache_context("uncond"):
noise_uncond = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
control_hidden_states=conditioning_latents,
control_hidden_states_scale=conditioning_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -1592,21 +1592,6 @@ class LTXPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LucyEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Lumina2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1892,21 +1877,6 @@ class QwenImageEditPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class QwenImageEditPlusPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class QwenImageImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -247,7 +247,6 @@ def find_pipeline_class(loaded_module):
def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
subfolder: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
@@ -354,7 +353,6 @@ def get_cached_module_file(
resolved_module_file = hf_hub_download(
pretrained_model_name_or_path,
module_file,
subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -412,7 +410,6 @@ def get_cached_module_file(
get_cached_module_file(
pretrained_model_name_or_path,
f"{module_needed}.py",
subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -427,7 +424,6 @@ def get_cached_module_file(
def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
subfolder: Optional[str] = None,
class_name: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
@@ -501,7 +497,6 @@ def get_class_from_dynamic_module(
final_module = get_cached_module_file(
pretrained_model_name_or_path,
module_file,
subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
+41
View File
@@ -1,3 +1,5 @@
from typing import Union
from ..utils import get_logger
from .import_utils import is_kernels_available
@@ -21,3 +23,42 @@ def _get_fa3_from_hub():
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
raise
if is_kernels_available():
from kernels import (
Device,
LayerRepository,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
)
_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
"RMSNorm": {
"cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
},
}
register_kernel_mapping(_KERNEL_MAPPING)
else:
# Stub to make decorators int transformers work when `kernels`
# is not installed.
def use_kernel_forward_from_hub(*args, **kwargs):
def decorator(cls):
return cls
return decorator
class LayerRepository:
def __init__(self, *args, **kwargs):
raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
def replace_kernel_forward_from_hub(*args, **kwargs):
raise RuntimeError(
"replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
)
def register_kernel_mapping(*args, **kwargs):
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
+1
View File
@@ -43,6 +43,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = AuraFlowPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
+2
View File
@@ -21,6 +21,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
@@ -43,6 +44,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"}
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
transformer_kwargs = {
"num_attention_heads": 4,
+19 -17
View File
@@ -50,6 +50,7 @@ class TokenizerWrapper:
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogView4Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -123,29 +124,30 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
for scheduler_cls in self.scheduler_classes:
components, _, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
pipe_from_pretrained.to(torch_device)
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
pipe_from_pretrained.to(torch_device)
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
self.assertTrue(
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@require_torch_accelerator
+4 -9
View File
@@ -55,8 +55,9 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
@@ -281,8 +282,9 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"patch_size": 1,
"in_channels": 8,
@@ -905,13 +907,6 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
assert max_diff < 1e-3
def test_flux_kohya_embedders_conversion(self):
"""Test that embedders load without throwing errors"""
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
self.pipeline.unload_lora_weights()
assert True
def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
+1 -1
View File
@@ -51,6 +51,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -253,7 +254,6 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
expected_slices = Expectations(
{
("cuda", 7): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
("xpu", 3): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
}
)
# fmt: on
+1
View File
@@ -37,6 +37,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
+25 -21
View File
@@ -39,6 +39,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = Lumina2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -140,30 +141,33 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
strict=False,
)
def test_lora_fuse_nan(self):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
self.assertTrue(np.isnan(out).all())
+1
View File
@@ -37,6 +37,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = MochiPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
+1
View File
@@ -37,6 +37,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
+3 -2
View File
@@ -31,8 +31,9 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {"shift": 7.0}
scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
+1
View File
@@ -55,6 +55,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"sample_size": 32,
"patch_size": 1,
+1
View File
@@ -42,6 +42,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
+3 -1
View File
@@ -50,6 +50,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -164,8 +165,9 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_peft_version_greater("0.13.2")
def test_lora_exclude_modules_wanvace(self):
scheduler_cls = self.scheduler_classes[0]
exclude_module_name = "vace_blocks.0.proj_out"
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
+1160 -1081
View File
File diff suppressed because it is too large Load Diff
@@ -35,13 +35,13 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, PEFTTesterMixin, unittest.TestCase):
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
+171 -153
View File
@@ -88,12 +88,7 @@ from ..testing_utils import (
if is_peft_available():
from peft import LoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import get_peft_model_state_dict
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
from diffusers.loaders.peft import PeftAdapterMixin
def caculate_expected_num_shards(index_map_path):
@@ -1118,6 +1113,177 @@ class ModelTesterMixin:
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from diffusers.loaders.peft import PeftAdapterMixin
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
torch.manual_seed(0)
output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
self.assertTrue(torch.allclose(loaded_v, retrieved_v))
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_wrong_adapter_name_raises_error(self):
from peft import LoraConfig
from diffusers.loaders.peft import PeftAdapterMixin
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with tempfile.TemporaryDirectory() as tmpdir:
wrong_name = "foo"
with self.assertRaises(ValueError) as err_context:
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
from peft import LoraConfig
from diffusers.loaders.peft import PeftAdapterMixin
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(model_file))
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_adapter_wrong_metadata_raises_error(self):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
from diffusers.loaders.peft import PeftAdapterMixin
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(model_file))
# Perturb the metadata in the state dict.
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with self.assertRaises(TypeError) as err_context:
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
@require_torch_accelerator
def test_cpu_offload(self):
if self.model_class._no_split_modules is None:
@@ -1775,154 +1941,6 @@ class ModelTesterMixin:
_ = loaded_model(**inputs_dict)
class PEFTTesterMixin:
@require_peft_backend
@pytest.mark.parametrize("rank,lora_alpha,use_dora", [(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
torch.manual_seed(0)
output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file)
state_dict_loaded = safetensors.torch.load_file(model_file)
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
for k, loaded_v in state_dict_loaded.items():
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert torch.allclose(loaded_v, retrieved_v)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
@require_peft_backend
def test_lora_wrong_adapter_name_raises_error(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
wrong_name = "foo"
with pytest.raises(ValueError, match=rf"Adapter name {wrong_name} not found in the model\."):
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
@require_peft_backend
@pytest.mark.parametrize("rank,lora_alpha,use_dora", [(4, 4, True), (4, 8, False), (8, 4, False)])
def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file)
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
@require_peft_backend
def test_lora_adapter_wrong_metadata_raises_error(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file)
# Perturb the metadata
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in list(lora_adapter_metadata.items()):
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with pytest.raises(TypeError, match=r"`LoraConfig` class could not be instantiated"):
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
identifier = uuid.uuid4()
@@ -30,13 +30,13 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class PriorTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class PriorTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = PriorTransformer
main_input_name = "hidden_states"
@@ -20,13 +20,13 @@ import torch
from diffusers import AuraFlowTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class AuraFlowTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = AuraFlowTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
@@ -22,12 +22,7 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
from diffusers.models.embeddings import ImageProjection
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import (
LoraHotSwappingForModelTesterMixin,
ModelTesterMixin,
PEFTTesterMixin,
TorchCompileTesterMixin,
)
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
@@ -83,7 +78,7 @@ def create_bria_ip_adapter_state_dict(model):
return ip_state_dict
class BriaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = BriaTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
@@ -22,12 +22,7 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
from diffusers.models.embeddings import ImageProjection
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import (
LoraHotSwappingForModelTesterMixin,
ModelTesterMixin,
PEFTTesterMixin,
TorchCompileTesterMixin,
)
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
@@ -83,7 +78,7 @@ def create_chroma_ip_adapter_state_dict(model):
return ip_state_dict
class ChromaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = ChromaTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
@@ -19,14 +19,17 @@ import torch
from diffusers import CogVideoXTransformer3DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogVideoXTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -19,13 +19,13 @@ import torch
from diffusers import CogView4Transformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogView3PlusTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView4Transformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -19,14 +19,17 @@ import torch
from diffusers import ConsisIDTransformer3DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class ConsisIDTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class ConsisIDTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = ConsisIDTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -22,12 +22,7 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
from diffusers.models.embeddings import ImageProjection
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
from ..test_modeling_common import (
LoraHotSwappingForModelTesterMixin,
ModelTesterMixin,
PEFTTesterMixin,
TorchCompileTesterMixin,
)
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
@@ -85,7 +80,7 @@ def create_flux_ip_adapter_state_dict(model):
return ip_state_dict
class FluxTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
@@ -19,14 +19,17 @@ import torch
from diffusers import HiDreamImageTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class HiDreamTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = HiDreamImageTransformer2DModel
main_input_name = "hidden_states"
model_split_percents = [0.8, 0.8, 0.9]
@@ -18,14 +18,17 @@ import torch
from diffusers import HunyuanVideoTransformer3DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class HunyuanVideoTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -18,14 +18,17 @@ import torch
from diffusers import HunyuanVideoFramepackTransformer3DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class HunyuanVideoTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoFramepackTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -20,13 +20,13 @@ import torch
from diffusers import LTXVideoTransformer3DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class LTXTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -19,14 +19,17 @@ import torch
from diffusers import Lumina2Transformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = Lumina2Transformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -20,13 +20,13 @@ import torch
from diffusers import MochiTransformer3DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class MochiTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = MochiTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -21,13 +21,13 @@ import torch
from diffusers import QwenImageTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class QwenImageTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
@@ -18,14 +18,17 @@ import torch
from diffusers import SanaTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class SanaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SanaTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -24,13 +24,13 @@ from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class SD3TransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
main_input_name = "hidden_states"
model_split_percents = [0.8, 0.8, 0.9]
@@ -18,14 +18,17 @@ import torch
from diffusers import SkyReelsV2Transformer3DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class SkyReelsV2Transformer3DTests(ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class SkyReelsV2Transformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
model_class = SkyReelsV2Transformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -18,14 +18,17 @@ import torch
from diffusers import WanTransformer3DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@@ -55,7 +55,6 @@ from ...testing_utils import (
from ..test_modeling_common import (
LoraHotSwappingForModelTesterMixin,
ModelTesterMixin,
PEFTTesterMixin,
TorchCompileTesterMixin,
UNetTesterMixin,
)
@@ -355,7 +354,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return custom_diffusion_attn_procs
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, PEFTTesterMixin, unittest.TestCase):
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
@@ -1084,6 +1083,48 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, PEFTTesterMix
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample_1 = model(**inputs_dict).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
with self.assertWarns(FutureWarning) as warning:
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
warning_message = str(warning.warnings[0].message)
assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
# import to still check for the rest of the stuff.
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
"LoRA injected UNet should produce different results."
)
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
"Loading from a saved checkpoint should produce identical results."
)
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
@@ -30,7 +30,7 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
enable_full_determinism()
class UNetMotionModelTests(ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin, unittest.TestCase):
class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetMotionModel
main_input_name = "sample"
@@ -48,7 +48,6 @@ class EasyAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_xformers_attention = False
required_optional_params = frozenset(
[
"num_inference_steps",
@@ -47,8 +47,8 @@ class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
test_xformers_attention = False
test_layerwise_casting = True
supports_dduf = False
@@ -18,13 +18,11 @@ import random
import unittest
import numpy as np
import pytest
import torch
from transformers import XLMRobertaTokenizerFast
from diffusers import DDIMScheduler, KandinskyPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
from diffusers.utils import is_transformers_version
from ...testing_utils import (
backend_empty_cache,
@@ -217,9 +215,6 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
dummy = Dummies()
return dummy.get_dummy_inputs(device=device, seed=seed)
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.56.2"), reason="Latest transformers changes the slices", strict=True
)
def test_kandinsky(self):
device = "cpu"
@@ -16,10 +16,8 @@
import unittest
import numpy as np
import pytest
from diffusers import KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyInpaintCombinedPipeline
from diffusers.utils import is_transformers_version
from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
@@ -75,9 +73,6 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
)
return inputs
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.56.2"), reason="Latest transformers changes the slices", strict=True
)
def test_kandinsky(self):
device = "cpu"
@@ -186,9 +181,6 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
inputs.pop("negative_image_embeds")
return inputs
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.56.2"), reason="Latest transformers changes the slices", strict=True
)
def test_kandinsky(self):
device = "cpu"
@@ -300,9 +292,6 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
inputs.pop("negative_image_embeds")
return inputs
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.56.2"), reason="Latest transformers changes the slices", strict=True
)
def test_kandinsky(self):
device = "cpu"
@@ -18,7 +18,6 @@ import random
import unittest
import numpy as np
import pytest
import torch
from PIL import Image
from transformers import XLMRobertaTokenizerFast
@@ -32,7 +31,6 @@ from diffusers import (
VQModel,
)
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
from diffusers.utils import is_transformers_version
from ...testing_utils import (
backend_empty_cache,
@@ -239,9 +237,6 @@ class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
dummies = Dummies()
return dummies.get_dummy_inputs(device=device, seed=seed)
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.56.2"), reason="Latest transformers changes the slices", strict=True
)
def test_kandinsky_img2img(self):
device = "cpu"
@@ -18,14 +18,12 @@ import random
import unittest
import numpy as np
import pytest
import torch
from PIL import Image
from transformers import XLMRobertaTokenizerFast
from diffusers import DDIMScheduler, KandinskyInpaintPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
from diffusers.utils import is_transformers_version
from ...testing_utils import (
backend_empty_cache,
@@ -233,9 +231,6 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
dummies = Dummies()
return dummies.get_dummy_inputs(device=device, seed=seed)
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.56.2"), reason="Latest transformers changes the slices", strict=True
)
def test_kandinsky_inpaint(self):
device = "cpu"

Some files were not shown because too many files have changed in this diff Show More