Compare commits

..

20 Commits

Author SHA1 Message Date
DN6 e3f6ba3fc1 Merge branch 'main' into custom-code-updates 2025-08-11 11:55:57 +05:30
Dhruv Nair fb8722e9ab update 2025-08-09 16:00:24 +02:00
Dhruv Nair 512044c5ea update 2025-08-09 15:06:18 +02:00
Dhruv Nair 6c85fcd899 update 2025-08-08 18:52:55 +02:00
Dhruv Nair 085e9cba36 update 2025-08-08 18:22:44 +02:00
DN6 919ee1aee3 Merge branch 'main' into custom-code-updates 2025-08-08 19:53:30 +05:30
DN6 9cda45701c Merge branch 'main' into custom-code-updates 2025-08-08 19:48:18 +05:30
DN6 c678e8a445 update 2025-08-08 19:47:50 +05:30
Dhruv Nair d1342d7464 update 2025-08-08 12:10:06 +02:00
Dhruv Nair 9a0cc463ee update 2025-08-06 19:32:23 +02:00
Dhruv Nair ef4e373a65 Merge branch 'main' into custom-code-updates 2025-08-06 19:31:05 +02:00
Dhruv Nair 1b4af6b7ef update 2025-08-06 17:43:21 +02:00
DN6 ea77fdc4b4 update 2025-08-06 17:17:51 +05:30
Dhruv Nair 255c5742aa update 2025-07-30 08:33:51 +02:00
Dhruv Nair 4524d43279 update 2025-07-30 08:24:25 +02:00
Dhruv Nair b6dc0b75f4 Merge branch 'custom-code-updates' of https://github.com/huggingface/diffusers into custom-code-updates 2025-07-30 08:19:38 +02:00
Dhruv Nair 966a2ff8df update 2025-07-29 21:06:40 +02:00
YiYi Xu 201da97dd0 Merge branch 'main' into custom-code-updates 2025-07-23 10:23:35 -10:00
DN6 4423097b23 update 2025-07-22 19:31:22 +05:30
Dhruv Nair 60d1b81023 update 2025-07-21 18:44:44 +02:00
64 changed files with 256 additions and 249 deletions
-2
View File
@@ -25,8 +25,6 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b
Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
</Tip>
Flux comes in the following variants:
+1 -1
View File
@@ -18,7 +18,7 @@
<Tip>
[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
+1 -1
View File
@@ -88,7 +88,7 @@ export_to_video(video, "output.mp4", fps=24)
</hfoption>
<hfoption id="inference speed">
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
```py
import torch
+1 -1
View File
@@ -20,7 +20,7 @@ Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn
<Tip>
[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
+1 -1
View File
@@ -119,7 +119,7 @@ export_to_video(output, "output.mp4", fps=16)
</hfoption>
<hfoption id="T2V inference speed">
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
```py
# pip install ftfy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
-1
View File
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
-1
View File
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -13,7 +13,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import typing
@@ -10,7 +10,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -10,7 +10,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,7 +12,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -13,7 +13,6 @@
# limitations under the License.
import functools
import math
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -163,7 +162,7 @@ class QwenEmbedRope(nn.Module):
self.axes_dim = axes_dim
pos_index = torch.arange(1024)
neg_index = torch.arange(1024).flip(0) * -1 - 1
pos_freqs = torch.cat(
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
@@ -171,7 +170,7 @@ class QwenEmbedRope(nn.Module):
],
dim=1,
)
neg_freqs = torch.cat(
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -180,8 +179,6 @@ class QwenEmbedRope(nn.Module):
dim=1,
)
self.rope_cache = {}
self.register_buffer("pos_freqs", pos_freqs, persistent=False)
self.register_buffer("neg_freqs", neg_freqs, persistent=False)
# 是否使用 scale rope
self.scale_rope = scale_rope
@@ -201,17 +198,33 @@ class QwenEmbedRope(nn.Module):
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
frame, height, width = video_fhw
rope_key = f"{frame}_{height}_{width}"
if not torch.compiler.is_compiling():
if rope_key not in self.rope_cache:
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
vid_freqs = self.rope_cache[rope_key]
else:
vid_freqs = self._compute_video_freqs(frame, height, width)
if rope_key not in self.rope_cache:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
self.rope_cache[rope_key] = freqs.clone().contiguous()
vid_freqs = self.rope_cache[rope_key]
if self.scale_rope:
max_vid_index = max(height // 2, width // 2)
@@ -223,25 +236,6 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
class QwenDoubleStreamAttnProcessor2_0:
"""
@@ -488,7 +482,6 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
_supports_gradient_checkpointing = True
_no_split_modules = ["QwenImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["QwenImageTransformerBlock"]
@register_to_config
def __init__(
+11 -8
View File
@@ -7,15 +7,9 @@ from ..utils import (
get_objects_from_module,
is_torch_available,
is_transformers_available,
logging,
)
logger = logging.get_logger(__name__)
logger.warning(
"Modular Diffusers is currently an experimental feature under active development. The API is subject to breaking changes in future releases."
)
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {}
@@ -67,8 +61,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PipelineState,
SequentialPipelineBlocks,
)
from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
InsertableDict,
OutputParam,
)
from .stable_diffusion_xl import (
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from .wan import WanAutoBlocks, WanModularPipeline
else:
import sys
-1
View File
@@ -10,7 +10,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
-1
View File
@@ -10,7 +10,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
-5
View File
@@ -1711,11 +1711,6 @@ class ModelTesterMixin:
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
if self.model_class.__name__ == "QwenImageTransformer2DModel":
pytest.skip(
"QwenImageTransformer2DModel doesn't support group offloading with disk. Needs to be investigated."
)
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
@@ -1,101 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 16)
@property
def output_shape(self):
return (16, 16)
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 7
vae_scale_factor = 4
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
"txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
@@ -124,22 +124,37 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
}
return inputs
def test_inference(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
def test_stable_diffusion_3_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs).images[0]
generated_slice = image.flatten()
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
output_same_prompt = pipe(**inputs).images[0]
# fmt: off
expected_slice = np.array([0.5112, 0.5228, 0.5235, 0.5524, 0.3188, 0.5017, 0.5574, 0.4899, 0.6812, 0.5991, 0.3908, 0.5213, 0.5582, 0.4457, 0.4204, 0.5616])
# fmt: on
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
inputs["prompt_3"] = "another different prompt"
output_different_prompts = pipe(**inputs).images[0]
self.assertTrue(
np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
)
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
assert max_diff > 1e-2
def test_stable_diffusion_3_different_negative_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt_2"] = "deformed"
inputs["negative_prompt_3"] = "blurry"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
assert max_diff > 1e-2
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -253,9 +268,40 @@ class StableDiffusion3PipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
# fmt: off
expected_slice = np.array([0.4648, 0.4404, 0.4177, 0.5063, 0.4800, 0.4287, 0.5425, 0.5190, 0.4717, 0.5430, 0.5195, 0.4766, 0.5361, 0.5122, 0.4612, 0.4871, 0.4749, 0.4058, 0.4756, 0.4678, 0.3804, 0.4832, 0.4822, 0.3799, 0.5103, 0.5034, 0.3953, 0.5073, 0.4839, 0.3884])
# fmt: on
expected_slice = np.array(
[
0.4648,
0.4404,
0.4177,
0.5063,
0.4800,
0.4287,
0.5425,
0.5190,
0.4717,
0.5430,
0.5195,
0.4766,
0.5361,
0.5122,
0.4612,
0.4871,
0.4749,
0.4058,
0.4756,
0.4678,
0.3804,
0.4832,
0.4822,
0.3799,
0.5103,
0.5034,
0.3953,
0.5073,
0.4839,
0.3884,
]
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
@@ -128,22 +128,37 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
}
return inputs
def test_inference(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
def test_stable_diffusion_3_img2img_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs).images[0]
generated_slice = image.flatten()
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
output_same_prompt = pipe(**inputs).images[0]
# fmt: off
expected_slice = np.array([0.4564, 0.5486, 0.4868, 0.5923, 0.3775, 0.5543, 0.4807, 0.4177, 0.3778, 0.5957, 0.5726, 0.4333, 0.6312, 0.5062, 0.4838, 0.5984])
# fmt: on
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
inputs["prompt_3"] = "another different prompt"
output_different_prompts = pipe(**inputs).images[0]
self.assertTrue(
np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
)
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
assert max_diff > 1e-2
def test_stable_diffusion_3_img2img_different_negative_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt_2"] = "deformed"
inputs["negative_prompt_3"] = "blurry"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
assert max_diff > 1e-2
@unittest.skip("Skip for now.")
def test_multi_vae(self):
@@ -192,16 +207,112 @@ class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
# fmt: off
expected_slices = Expectations(
{
("xpu", 3): np.array([0.5117, 0.4421, 0.3852, 0.5044, 0.4219, 0.3262, 0.5024, 0.4329, 0.3276, 0.4978, 0.4412, 0.3355, 0.4983, 0.4338, 0.3279, 0.4893, 0.4241, 0.3129, 0.4875, 0.4253, 0.3030, 0.4961, 0.4267, 0.2988, 0.5029, 0.4255, 0.3054, 0.5132, 0.4248, 0.3222]),
("cuda", 7): np.array([0.5435, 0.4673, 0.5732, 0.4438, 0.3557, 0.4912, 0.4331, 0.3491, 0.4915, 0.4287, 0.347, 0.4849, 0.4355, 0.3469, 0.4871, 0.4431, 0.3538, 0.4912, 0.4521, 0.3643, 0.5059, 0.4587, 0.373, 0.5166, 0.4685, 0.3845, 0.5264, 0.4746, 0.3914, 0.5342]),
("cuda", 8): np.array([0.5146, 0.4385, 0.3826, 0.5098, 0.4150, 0.3218, 0.5142, 0.4312, 0.3298, 0.5127, 0.4431, 0.3411, 0.5171, 0.4424, 0.3374, 0.5088, 0.4348, 0.3242, 0.5073, 0.4380, 0.3174, 0.5132, 0.4397, 0.3115, 0.5132, 0.4343, 0.3118, 0.5219, 0.4328, 0.3256]),
("xpu", 3): np.array(
[
0.5117,
0.4421,
0.3852,
0.5044,
0.4219,
0.3262,
0.5024,
0.4329,
0.3276,
0.4978,
0.4412,
0.3355,
0.4983,
0.4338,
0.3279,
0.4893,
0.4241,
0.3129,
0.4875,
0.4253,
0.3030,
0.4961,
0.4267,
0.2988,
0.5029,
0.4255,
0.3054,
0.5132,
0.4248,
0.3222,
]
),
("cuda", 7): np.array(
[
0.5435,
0.4673,
0.5732,
0.4438,
0.3557,
0.4912,
0.4331,
0.3491,
0.4915,
0.4287,
0.347,
0.4849,
0.4355,
0.3469,
0.4871,
0.4431,
0.3538,
0.4912,
0.4521,
0.3643,
0.5059,
0.4587,
0.373,
0.5166,
0.4685,
0.3845,
0.5264,
0.4746,
0.3914,
0.5342,
]
),
("cuda", 8): np.array(
[
0.5146,
0.4385,
0.3826,
0.5098,
0.4150,
0.3218,
0.5142,
0.4312,
0.3298,
0.5127,
0.4431,
0.3411,
0.5171,
0.4424,
0.3374,
0.5088,
0.4348,
0.3242,
0.5073,
0.4380,
0.3174,
0.5132,
0.4397,
0.3115,
0.5132,
0.4343,
0.3118,
0.5219,
0.4328,
0.3256,
]
),
}
)
# fmt: on
expected_slice = expected_slices.get_expectation()
@@ -132,23 +132,37 @@ class StableDiffusion3InpaintPipelineFastTests(PipelineLatentTesterMixin, unitte
}
return inputs
def test_inference(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
def test_stable_diffusion_3_inpaint_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs).images[0]
generated_slice = image.flatten()
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
output_same_prompt = pipe(**inputs).images[0]
# fmt: off
expected_slice = np.array([0.5035, 0.6661, 0.5859, 0.413, 0.4224, 0.4234, 0.7181, 0.5062, 0.5183, 0.6877, 0.5074, 0.585, 0.6111, 0.5422, 0.5306, 0.5891])
# fmt: on
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
inputs["prompt_3"] = "another different prompt"
output_different_prompts = pipe(**inputs).images[0]
self.assertTrue(
np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
)
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
assert max_diff > 1e-2
def test_stable_diffusion_3_inpaint_different_negative_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt_2"] = "deformed"
inputs["negative_prompt_3"] = "blurry"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
assert max_diff > 1e-2
@unittest.skip("Skip for now.")
def test_multi_vae(self):
pass