Compare commits

..

3 Commits

Author SHA1 Message Date
DN6 532b395718 update 2025-07-17 21:56:48 +05:30
DN6 5c43924ac2 update 2025-07-17 19:57:45 +05:30
DN6 a633289e10 update 2025-07-16 19:41:48 +05:30
20 changed files with 173 additions and 1254 deletions
-141
View File
@@ -1,141 +0,0 @@
name: Fast PR tests for Modular
on:
pull_request:
branches: [main]
paths:
- "src/diffusers/modular_pipelines/**.py"
- "src/diffusers/models/modeling_utils.py"
- "src/diffusers/models/model_loading_utils.py"
- "src/diffusers/pipelines/pipeline_utils.py"
- "src/diffusers/pipeline_loading_utils.py"
- "src/diffusers/loaders/lora_base.py"
- "src/diffusers/loaders/lora_pipeline.py"
- "src/diffusers/loaders/peft.py"
- "tests/modular_pipelines/**.py"
- ".github/**.yml"
- "utils/**.py"
- "setup.py"
push:
branches:
- ci-*
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
check_repository_consistency:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_support_list.py
make deps_table_check_updated
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch Modular Pipeline CPU tests
framework: pytorch_pipelines
runner: aws-highmemory-32-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_modular_pipelines
name: ${{ matrix.config.name }}
runs-on:
group: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports
+1 -4
View File
@@ -763,7 +763,4 @@ class LegacyConfigMixin(ConfigMixin):
# resolve remapping # resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls) remapped_class = _fetch_remapped_cls_from_config(config, cls)
if remapped_class is cls: return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
else:
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
+4 -1
View File
@@ -24,7 +24,7 @@ from typing_extensions import Self
from .. import __version__ from .. import __version__
from ..quantizers import DiffusersAutoQuantizer from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import empty_device_cache from ..utils.torch_utils import device_synchronize, empty_device_cache
from .single_file_utils import ( from .single_file_utils import (
SingleFileComponentError, SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers, convert_animatediff_checkpoint_to_diffusers,
@@ -431,7 +431,10 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules, keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys, unexpected_keys=unexpected_keys,
) )
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache() empty_device_cache()
device_synchronize()
else: else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
+7 -1
View File
@@ -46,7 +46,7 @@ from ..utils import (
) )
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import empty_device_cache from ..utils.torch_utils import device_synchronize, empty_device_cache
if is_transformers_available(): if is_transformers_available():
@@ -1690,7 +1690,10 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available(): if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache() empty_device_cache()
device_synchronize()
else: else:
model.load_state_dict(diffusers_format_checkpoint, strict=False) model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2150,7 +2153,10 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available(): if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache() empty_device_cache()
device_synchronize()
else: else:
model.load_state_dict(diffusers_format_checkpoint) model.load_state_dict(diffusers_format_checkpoint)
+3 -1
View File
@@ -19,7 +19,7 @@ from ..models.embeddings import (
) )
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache from ..utils.torch_utils import device_synchronize, empty_device_cache
if is_accelerate_available(): if is_accelerate_available():
@@ -82,6 +82,7 @@ class FluxTransformer2DLoadersMixin:
device_map = {"": self.device} device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache() empty_device_cache()
device_synchronize()
return image_projection return image_projection
@@ -157,6 +158,7 @@ class FluxTransformer2DLoadersMixin:
key_id += 1 key_id += 1
empty_device_cache() empty_device_cache()
device_synchronize()
return attn_procs return attn_procs
+3 -1
View File
@@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache from ..utils.torch_utils import device_synchronize, empty_device_cache
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -82,6 +82,7 @@ class SD3Transformer2DLoadersMixin:
) )
empty_device_cache() empty_device_cache()
device_synchronize()
return attn_procs return attn_procs
@@ -151,6 +152,7 @@ class SD3Transformer2DLoadersMixin:
device_map = {"": self.device} device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype) load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache() empty_device_cache()
device_synchronize()
return image_proj return image_proj
+3 -1
View File
@@ -43,7 +43,7 @@ from ..utils import (
is_torch_version, is_torch_version,
logging, logging,
) )
from ..utils.torch_utils import empty_device_cache from ..utils.torch_utils import device_synchronize, empty_device_cache
from .lora_base import _func_optionally_disable_offloading from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers from .utils import AttnProcsLayers
@@ -755,6 +755,7 @@ class UNet2DConditionLoadersMixin:
device_map = {"": self.device} device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache() empty_device_cache()
device_synchronize()
return image_projection return image_projection
@@ -853,6 +854,7 @@ class UNet2DConditionLoadersMixin:
key_id += 2 key_id += 2
empty_device_cache() empty_device_cache()
device_synchronize()
return attn_procs return attn_procs
+5 -7
View File
@@ -62,7 +62,7 @@ from ..utils.hub_utils import (
load_or_create_model_card, load_or_create_model_card,
populate_model_card, populate_model_card,
) )
from ..utils.torch_utils import empty_device_cache from ..utils.torch_utils import device_synchronize, empty_device_cache
from .model_loading_utils import ( from .model_loading_utils import (
_caching_allocator_warmup, _caching_allocator_warmup,
_determine_device_map, _determine_device_map,
@@ -1540,7 +1540,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache() empty_device_cache()
device_synchronize()
if offload_index is not None and len(offload_index) > 0: if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder) save_offload_index(offload_index, offload_folder)
@@ -1877,9 +1880,4 @@ class LegacyModelMixin(ModelMixin):
# resolve remapping # resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls) remapped_class = _fetch_remapped_cls_from_config(config, cls)
if remapped_class is cls: return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
return super(LegacyModelMixin, remapped_class).from_pretrained(
pretrained_model_name_or_path, **kwargs_copy
)
else:
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
@@ -323,6 +323,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
""" """
config_name = "config.json" config_name = "config.json"
model_name = None
@classmethod @classmethod
def _get_signature_keys(cls, obj): def _get_signature_keys(cls, obj):
@@ -333,6 +334,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
return expected_modules, optional_parameters return expected_modules, optional_parameters
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
@@ -358,7 +367,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
trust_remote_code, pretrained_model_name_or_path, has_remote_code trust_remote_code, pretrained_model_name_or_path, has_remote_code
) )
if not (has_remote_code and trust_remote_code): if not (has_remote_code and trust_remote_code):
raise ValueError("TODO") raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
class_ref = config["auto_map"][cls.__name__] class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".") module_file, class_name = class_ref.split(".")
@@ -367,7 +378,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
pretrained_model_name_or_path, pretrained_model_name_or_path,
module_file=module_file, module_file=module_file,
class_name=class_name, class_name=class_name,
is_modular=True,
**hub_kwargs, **hub_kwargs,
**kwargs, **kwargs,
) )
@@ -479,22 +489,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
return list(combined_dict.values()) return list(combined_dict.values())
@property
def input_names(self) -> List[str]:
return [input_param.name for input_param in self.inputs]
@property
def intermediate_input_names(self) -> List[str]:
return [input_param.name for input_param in self.intermediate_inputs]
@property
def intermediate_output_names(self) -> List[str]:
return [output_param.name for output_param in self.intermediate_outputs]
@property
def output_names(self) -> List[str]:
return [output_param.name for output_param in self.outputs]
class PipelineBlock(ModularPipelineBlocks): class PipelineBlock(ModularPipelineBlocks):
""" """
@@ -2841,8 +2835,3 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
type_hint=type_hint, type_hint=type_hint,
**spec_dict, **spec_dict,
) )
def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)
@@ -93,7 +93,7 @@ class ComponentSpec:
config: Optional[FrozenDict] = None config: Optional[FrozenDict] = None
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default=None, metadata={"loading": True}) subfolder: Optional[str] = field(default="", metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True}) variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True}) revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
@@ -744,6 +744,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
timestep=None, timestep=None,
is_strength_max=True, is_strength_max=True,
add_noise=True, add_noise=True,
return_noise=False,
return_image_latents=False,
): ):
shape = ( shape = (
batch_size, batch_size,
@@ -766,7 +768,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
if image.shape[1] == 4: if image.shape[1] == 4:
image_latents = image.to(device=device, dtype=dtype) image_latents = image.to(device=device, dtype=dtype)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
elif latents is None and not is_strength_max: elif return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(components, image=image, generator=generator) image_latents = self._encode_vae_image(components, image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
@@ -784,7 +786,13 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = image_latents.to(device) latents = image_latents.to(device)
outputs = (latents, noise, image_latents) outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_image_latents:
outputs += (image_latents,)
return outputs return outputs
@@ -856,7 +864,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint( block_state.latents, block_state.noise = self.prepare_latents_inpaint(
components, components,
block_state.batch_size * block_state.num_images_per_prompt, block_state.batch_size * block_state.num_images_per_prompt,
components.num_channels_latents, components.num_channels_latents,
@@ -870,6 +878,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
timestep=block_state.latent_timestep, timestep=block_state.latent_timestep,
is_strength_max=block_state.is_strength_max, is_strength_max=block_state.is_strength_max,
add_noise=block_state.add_noise, add_noise=block_state.add_noise,
return_noise=True,
return_image_latents=False,
) )
# 7. Prepare mask latent variables # 7. Prepare mask latent variables
@@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch import torch
import torch.nn.functional as F
from transformers import ( from transformers import (
CLIPImageProcessor, CLIPImageProcessor,
CLIPTextModel, CLIPTextModel,
@@ -37,13 +38,7 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin, TextualInversionLoaderMixin,
) )
from ...models import ( from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
AutoencoderKL,
ControlNetUnionModel,
ImageProjection,
MultiControlNetUnionModel,
UNet2DConditionModel,
)
from ...models.attention_processor import ( from ...models.attention_processor import (
AttnProcessor2_0, AttnProcessor2_0,
XFormersAttnProcessor, XFormersAttnProcessor,
@@ -267,9 +262,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer, tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
controlnet: Union[ controlnet: ControlNetUnionModel,
ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
],
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False, requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True, force_zeros_for_empty_prompt: bool = True,
@@ -279,8 +272,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
): ):
super().__init__() super().__init__()
if isinstance(controlnet, (list, tuple)): if not isinstance(controlnet, ControlNetUnionModel):
controlnet = MultiControlNetUnionModel(controlnet) raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
self.register_modules( self.register_modules(
vae=vae, vae=vae,
@@ -656,7 +649,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet_conditioning_scale=1.0, controlnet_conditioning_scale=1.0,
control_guidance_start=0.0, control_guidance_start=0.0,
control_guidance_end=1.0, control_guidance_end=1.0,
control_mode=None,
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
@@ -730,44 +722,28 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
) )
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetUnionModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)
# Check `image` # Check `image`
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
if isinstance(controlnet, ControlNetUnionModel): )
for image_ in image: if (
self.check_image(image_, prompt, prompt_embeds) isinstance(self.controlnet, ControlNetModel)
elif isinstance(controlnet, MultiControlNetUnionModel): or is_compiled
if not isinstance(image, list): and isinstance(self.controlnet._orig_mod, ControlNetModel)
raise TypeError("For multiple controlnets: `image` must be type `list`") ):
elif not all(isinstance(i, list) for i in image): self.check_image(image, prompt, prompt_embeds)
raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.") elif (
elif len(image) != len(self.controlnet.nets): isinstance(self.controlnet, ControlNetUnionModel)
raise ValueError( or is_compiled
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
) ):
self.check_image(image, prompt, prompt_embeds)
for images_ in image: else:
for image_ in images_: assert False
self.check_image(image_, prompt, prompt_embeds)
if not isinstance(control_guidance_start, (tuple, list)): if not isinstance(control_guidance_start, (tuple, list)):
control_guidance_start = [control_guidance_start] control_guidance_start = [control_guidance_start]
if isinstance(controlnet, MultiControlNetUnionModel):
if len(control_guidance_start) != len(self.controlnet.nets):
raise ValueError(
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
)
if not isinstance(control_guidance_end, (tuple, list)): if not isinstance(control_guidance_end, (tuple, list)):
control_guidance_end = [control_guidance_end] control_guidance_end = [control_guidance_end]
@@ -786,15 +762,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
if end > 1.0: if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
# Check `control_mode`
if isinstance(controlnet, ControlNetUnionModel):
if max(control_mode) >= controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
elif isinstance(controlnet, MultiControlNetUnionModel):
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
if max(_control_mode) >= _controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None: if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError( raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -1082,7 +1049,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None, prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None, image: PipelineImageInput = None,
control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None, control_image: PipelineImageInput = None,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
strength: float = 0.8, strength: float = 0.8,
@@ -1107,7 +1074,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
guess_mode: bool = False, guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0, control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int], List[List[int]]]] = None, control_mode: Optional[Union[int, List[int]]] = None,
original_size: Tuple[int, int] = None, original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0), crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None, target_size: Tuple[int, int] = None,
@@ -1137,13 +1104,13 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The initial image will be used as the starting point for the image generation process. Can also accept The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`, if passing latents directly, it will not be encoded again. image latents as `image`, if passing latents directly, it will not be encoded again.
control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): control_image (`PipelineImageInput`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
images must be passed as a list such that each element of the list can be correctly batched for input init, images must be passed as a list such that each element of the list can be correctly batched for
to a single ControlNet. input to a single controlnet.
height (`int`, *optional*, defaults to the size of control_image): height (`int`, *optional*, defaults to the size of control_image):
The height in pixels of the generated image. Anything below 512 pixels won't work well for The height in pixels of the generated image. Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -1217,21 +1184,16 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`self.processor` in `self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
the corresponding scale as a list. corresponding scale as a list.
guess_mode (`bool`, *optional*, defaults to `False`): guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying. The percentage of total steps at which the controlnet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying. The percentage of total steps at which the controlnet stops applying.
control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
where each ControlNet should have its corresponding control mode list. Should reflect the order of
conditions in control_image
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1311,6 +1273,12 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
if not isinstance(control_image, list): if not isinstance(control_image, list):
control_image = [control_image] control_image = [control_image]
else: else:
@@ -1319,56 +1287,37 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
if not isinstance(control_mode, list): if not isinstance(control_mode, list):
control_mode = [control_mode] control_mode = [control_mode]
if isinstance(controlnet, MultiControlNetUnionModel): if len(control_image) != len(control_mode):
control_image = [[item] for item in control_image] raise ValueError("Expected len(control_image) == len(control_type)")
control_mode = [[item] for item in control_mode]
# align format for control guidance num_control_type = controlnet.config.num_control_type
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
if isinstance(controlnet_conditioning_scale, float):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
# 1. Check inputs # 1. Check inputs
self.check_inputs( control_type = [0 for _ in range(num_control_type)]
prompt, for _image, control_idx in zip(control_image, control_mode):
prompt_2, control_type[control_idx] = 1
control_image, self.check_inputs(
strength, prompt,
num_inference_steps, prompt_2,
callback_steps, _image,
negative_prompt, strength,
negative_prompt_2, num_inference_steps,
prompt_embeds, callback_steps,
negative_prompt_embeds, negative_prompt,
pooled_prompt_embeds, negative_prompt_2,
negative_pooled_prompt_embeds, prompt_embeds,
ip_adapter_image, negative_prompt_embeds,
ip_adapter_image_embeds, pooled_prompt_embeds,
controlnet_conditioning_scale, negative_pooled_prompt_embeds,
control_guidance_start, ip_adapter_image,
control_guidance_end, ip_adapter_image_embeds,
control_mode, controlnet_conditioning_scale,
callback_on_step_end_tensor_inputs, control_guidance_start,
) control_guidance_end,
callback_on_step_end_tensor_inputs,
)
if isinstance(controlnet, ControlNetUnionModel): control_type = torch.Tensor(control_type)
control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
elif isinstance(controlnet, MultiControlNetUnionModel):
control_type = [
torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
]
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._clip_skip = clip_skip self._clip_skip = clip_skip
@@ -1385,11 +1334,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
device = self._execution_device device = self._execution_device
global_pool_conditions = ( global_pool_conditions = controlnet.config.global_pool_conditions
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetUnionModel)
else controlnet.nets[0].config.global_pool_conditions
)
guess_mode = guess_mode or global_pool_conditions guess_mode = guess_mode or global_pool_conditions
# 3.1. Encode input prompt # 3.1. Encode input prompt
@@ -1427,55 +1372,22 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
self.do_classifier_free_guidance, self.do_classifier_free_guidance,
) )
# 4.1 Prepare image # 4. Prepare image and controlnet_conditioning_image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
# 4.2 Prepare control images for idx, _ in enumerate(control_image):
if isinstance(controlnet, ControlNetUnionModel): control_image[idx] = self.prepare_control_image(
control_images = [] image=control_image[idx],
width=width,
for image_ in control_image: height=height,
image_ = self.prepare_control_image( batch_size=batch_size * num_images_per_prompt,
image=image_, num_images_per_prompt=num_images_per_prompt,
width=width, device=device,
height=height, dtype=controlnet.dtype,
batch_size=batch_size * num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt, guess_mode=guess_mode,
device=device, )
dtype=controlnet.dtype, height, width = control_image[idx].shape[-2:]
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
control_images.append(image_)
control_image = control_images
height, width = control_image[0].shape[-2:]
elif isinstance(controlnet, MultiControlNetUnionModel):
control_images = []
for control_image_ in control_image:
images = []
for image_ in control_image_:
image_ = self.prepare_control_image(
image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
images.append(image_)
control_images.append(images)
control_image = control_images
height, width = control_image[0][0].shape[-2:]
# 5. Prepare timesteps # 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1502,11 +1414,10 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
# 7.1 Create tensor stating which controlnets to keep # 7.1 Create tensor stating which controlnets to keep
controlnet_keep = [] controlnet_keep = []
for i in range(len(timesteps)): for i in range(len(timesteps)):
keeps = [ controlnet_keep.append(
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) 1.0
for s, e in zip(control_guidance_start, control_guidance_end) - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
] )
controlnet_keep.append(keeps)
# 7.2 Prepare added time ids & embeddings # 7.2 Prepare added time ids & embeddings
original_size = original_size or (height, width) original_size = original_size or (height, width)
@@ -1549,25 +1460,12 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt_embeds = prompt_embeds.to(device) prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device) add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device) add_time_ids = add_time_ids.to(device)
control_type = (
control_type_repeat_factor = ( control_type.reshape(1, -1)
batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1) .to(device, dtype=prompt_embeds.dtype)
.repeat(batch_size * num_images_per_prompt * 2, 1)
) )
if isinstance(controlnet, ControlNetUnionModel):
control_type = (
control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
.repeat(control_type_repeat_factor, 1)
)
elif isinstance(controlnet, MultiControlNetUnionModel):
control_type = [
_control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
.repeat(control_type_repeat_factor, 1)
for _control_type in control_type
]
# 8. Denoising loop # 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -383,8 +383,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# set timesteps # set timesteps
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
# scale the initial noise by the standard deviation required by the scheduler latents = latents * np.float64(self.scheduler.init_noise_sigma)
latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma latents = latents * np.float64(self.scheduler.init_noise_sigma)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# Scale the initial noise by the standard deviation required by the scheduler # Scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma latents = latents * np.float64(self.scheduler.init_noise_sigma)
# 5. Add noise to image # 5. Add noise to image
noise_level = np.array([noise_level]).astype(np.int64) noise_level = np.array([noise_level]).astype(np.int64)
View File
@@ -1,511 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import tempfile
import unittest
from typing import Any, Dict
import numpy as np
import torch
from PIL import Image
from diffusers import (
ClassifierFreeGuidance,
ComponentsManager,
ModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from diffusers.loaders import ModularIPAdapterMixin
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
torch_device,
)
from ...models.unets.test_models_unet_2d_condition import (
create_ip_adapter_state_dict,
)
from ..test_modular_pipelines_common import (
ModularPipelineTesterMixin,
)
enable_full_determinism()
class SDXLModularTests:
"""
This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
"""
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
repo = "hf-internal-testing/tiny-sdxl-modular"
params = frozenset(
[
"prompt",
"height",
"width",
"negative_prompt",
"cross_attention_kwargs",
"image",
"mask_image",
]
)
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
pipeline.load_default_components(torch_dtype=torch_dtype)
return pipeline
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
sd_pipe = self.get_pipeline()
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs, output="images")
image_slice = image[0, -3:, -3:, -1]
assert image.shape == expected_image_shape
assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
"Image Slice does not match expected slice"
)
class SDXLModularIPAdapterTests:
"""
This mixin is designed to test IP Adapter.
"""
def test_pipeline_inputs_and_blocks(self):
blocks = self.pipeline_blocks_class()
parameters = blocks.input_names
assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
assert "ip_adapter_image" in parameters, (
"`ip_adapter_image` argument must be supported by the `__call__` method"
)
assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
_ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names
intermediate_parameters = blocks.intermediate_input_names
assert "ip_adapter_image" not in parameters, (
"`ip_adapter_image` argument must be removed from the `__call__` method"
)
assert "ip_adapter_image_embeds" not in intermediate_parameters, (
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method"
)
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_masks(self, input_size: int = 64):
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
_masks[0, :, :, : int(input_size / 2)] = 1
return _masks
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names
if "image" in parameters and "strength" in parameters:
inputs["num_inference_steps"] = 4
inputs["output_type"] = "np"
return inputs
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for IP-Adapter.
The following scenarios are tested:
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
pipe = blocks.init_pipeline(self.repo)
pipe.load_default_components(torch_dtype=torch.float32)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
if expected_pipe_slice is None:
output_without_adapter = pipe(**inputs, output="images")
else:
output_without_adapter = expected_pipe_slice
# 1. Single IP-Adapter test cases
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
assert max_diff_without_adapter_scale < expected_max_diff, (
"Output without ip-adapter must be same as normal inference"
)
assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
# 2. Multi IP-Adapter test cases
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
# forward pass with multi ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([0.0, 0.0])
output_without_multi_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with multi ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([42.0, 42.0])
output_with_multi_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_multi_adapter_scale = np.abs(
output_without_multi_adapter_scale - output_without_adapter
).max()
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
"Output without multi-ip-adapter must be same as normal inference"
)
assert max_diff_with_multi_adapter_scale > 1e-2, (
"Output with multi-ip-adapter scale must be different from normal inference"
)
class SDXLModularControlNetTests:
"""
This mixin is designed to test ControlNet.
"""
def test_pipeline_inputs(self):
blocks = self.pipeline_blocks_class()
parameters = blocks.input_names
assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
assert "controlnet_conditioning_scale" in parameters, (
"`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
)
def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
controlnet_embedder_scale_factor = 2
image = torch.randn(
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
device=torch_device,
)
inputs["control_image"] = image
return inputs
def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for ControlNet.
The following scenarios are tested:
- Single ControlNet with scale=0 should produce same output as no ControlNet.
- Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass without controlnet
inputs = self.get_dummy_inputs(torch_device)
output_without_controlnet = pipe(**inputs, output="images")
output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
# forward pass with single controlnet, but scale=0 which should have no effect
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
inputs["controlnet_conditioning_scale"] = 0.0
output_without_controlnet_scale = pipe(**inputs, output="images")
output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
# forward pass with single controlnet, but with scale of adapter weights
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
inputs["controlnet_conditioning_scale"] = 42.0
output_with_controlnet_scale = pipe(**inputs, output="images")
output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
assert max_diff_without_controlnet_scale < expected_max_diff, (
"Output without controlnet must be same as normal inference"
)
assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"
def test_controlnet_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
max_diff = np.abs(out_cfg - out_no_cfg).max()
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
class SDXLModularGuiderTests:
def test_guider_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs(torch_device)
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs(torch_device)
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
max_diff = np.abs(out_cfg - out_no_cfg).max()
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
class SDXLModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.5966781,
0.62939394,
0.48465094,
0.51573336,
0.57593524,
0.47035995,
0.53410417,
0.51436996,
0.47313565,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
sd_pipe = self.get_pipeline().to(torch_device)
pipes.append(sd_pipe)
cm = ComponentsManager()
cm.enable_auto_cpu_offload(device=torch_device)
sd_pipe = self.get_pipeline(components_manager=cm)
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_stable_diffusion_xl_save_from_pretrained(self):
pipes = []
sd_pipe = self.get_pipeline().to(torch_device)
pipes.append(sd_pipe)
with tempfile.TemporaryDirectory() as tmpdirname:
sd_pipe.save_pretrained(tmpdirname)
sd_pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
sd_pipe.load_default_components(torch_dtype=torch.float32)
sd_pipe.to(torch_device)
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
class SDXLImg2ImgModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
inputs["image"] = image
inputs["strength"] = 0.8
return inputs
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.56943184,
0.4702148,
0.48048905,
0.6235963,
0.551138,
0.49629188,
0.60031277,
0.5688907,
0.43996853,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
class SDXLInpaintingModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
# create mask
image[8:, 8:, :] = 255
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
inputs["image"] = init_image
inputs["mask_image"] = mask_image
inputs["strength"] = 1.0
return inputs
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.40872607,
0.38842705,
0.34893104,
0.47837183,
0.43792963,
0.5332134,
0.3716843,
0.47274873,
0.45000193,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -1,330 +0,0 @@
import gc
import unittest
from typing import Callable, Union
import numpy as np
import torch
import diffusers
from diffusers.utils import logging
from diffusers.utils.dummy_pt_objects import ModularPipeline, ModularPipelineBlocks
from diffusers.utils.testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
require_torch,
torch_device,
)
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
@require_torch
class ModularPipelineTesterMixin:
"""
This mixin is designed to be used with unittest.TestCase classes.
It provides a set of common tests for each modular pipeline,
including:
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
- test_to_device: check if the pipeline's __call__ method can handle different devices
"""
# Canonical parameters that are passed to `__call__` regardless
# of the type of pipeline. They are always optional and have common
# sense default values.
optional_params = frozenset(
[
"num_inference_steps",
"num_images_per_prompt",
"latents",
"output_type",
]
)
# this is modular specific: generator needs to be a intermediate input because it's mutable
intermediate_params = frozenset(
[
"generator",
]
)
def get_generator(self, seed):
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device).manual_seed(seed)
return generator
@property
def pipeline_class(self) -> Union[Callable, ModularPipeline]:
raise NotImplementedError(
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
"See existing pipeline tests for reference."
)
@property
def repo(self) -> str:
raise NotImplementedError(
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
)
@property
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
raise NotImplementedError(
"You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
"See existing pipeline tests for reference."
)
def get_pipeline(self):
raise NotImplementedError(
"You need to implement `get_pipeline(self)` in the child test class. "
"See existing pipeline tests for reference."
)
def get_dummy_inputs(self, device, seed=0):
raise NotImplementedError(
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
"See existing pipeline tests for reference."
)
@property
def params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `params` in the child test class. "
"`params` are checked for if all values are present in `__call__`'s signature."
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
"image pipelines, including prompts and prompt embedding overrides."
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
"with non-configurable height and width arguments should set the attribute as "
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
"See existing pipeline tests for reference."
)
@property
def batch_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `batch_params` in the child test class. "
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
"image pipeline `negative_prompt` is not batched should set the attribute as "
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
"See existing pipeline tests for reference."
)
def setUp(self):
# clean up the VRAM before each test
super().setUp()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test in case of CUDA runtime errors
super().tearDown()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def test_pipeline_call_signature(self):
pipe = self.get_pipeline()
input_parameters = pipe.blocks.input_names
intermediate_parameters = pipe.blocks.intermediate_input_names
optional_parameters = pipe.default_call_parameters
def _check_for_parameters(parameters, expected_parameters, param_type):
remaining_parameters = {param for param in parameters if param not in expected_parameters}
assert (
len(remaining_parameters) == 0
), f"Required {param_type} parameters not present: {remaining_parameters}"
_check_for_parameters(self.params, input_parameters, "input")
_check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
_check_for_parameters(self.optional_params, optional_parameters, "optional")
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
pipe = self.get_pipeline()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# prepare batched inputs
batched_inputs = []
for batch_size in batch_sizes:
batched_input = {}
batched_input.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
batched_input[name] = batch_size * [value]
if batch_generator and "generator" in inputs:
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_input["batch_size"] = batch_size
batched_inputs.append(batched_input)
logger.setLevel(level=diffusers.logging.WARNING)
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
output = pipe(**batched_input, output="images")
assert len(output) == batch_size, "Output is different from expected batch size"
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
):
pipe = self.get_pipeline()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
output = pipe(**inputs, output="images")
output_batch = pipe(**batched_inputs, output="images")
assert output_batch.shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_float16_inference(self, expected_max_diff=5e-2):
pipe = self.get_pipeline()
pipe.to(torch_device, torch.float32)
pipe.set_progress_bar_config(disable=None)
pipe_fp16 = self.get_pipeline()
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs, output="images")
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
if isinstance(output, torch.Tensor):
output = output.cpu()
output_fp16 = output_fp16.cpu()
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
@require_accelerator
def test_to_device(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"
pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
assert all(
device == torch_device for device in model_devices
), "All pipeline components are not on accelerator device"
def test_inference_is_not_nan_cpu(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
output = pipe(**self.get_dummy_inputs("cpu"), output="images")
assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
@require_accelerator
def test_inference_is_not_nan(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
output = pipe(**self.get_dummy_inputs(torch_device), output="images")
assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
def test_num_images_per_prompt(self):
pipe = self.get_pipeline()
if "num_images_per_prompt" not in pipe.blocks.input_names:
return
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
batch_sizes = [1, 2]
num_images_per_prompts = [1, 2]
for batch_size in batch_sizes:
for num_images_per_prompt in num_images_per_prompts:
inputs = self.get_dummy_inputs(torch_device)
for key in inputs.keys():
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
assert images.shape[0] == batch_size * num_images_per_prompt
@require_accelerator
def test_components_auto_cpu_offload(self):
base_pipe = self.get_pipeline().to(torch_device)
for component in base_pipe.components:
assert component.device == torch_device
cm = ComponentsManager()
cm.enable_auto_cpu_offload(device=torch_device)
offload_pipe = self.get_pipeline(components_manager=cm)
+26 -31
View File
@@ -20,6 +20,12 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
] ]
) )
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
IMAGE_VARIATION_PARAMS = frozenset( IMAGE_VARIATION_PARAMS = frozenset(
[ [
"image", "image",
@@ -29,6 +35,8 @@ IMAGE_VARIATION_PARAMS = frozenset(
] ]
) )
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset( TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
[ [
"prompt", "prompt",
@@ -42,6 +50,8 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
] ]
) )
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset( TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[ [
# Text guided image variation with an image mask # Text guided image variation with an image mask
@@ -57,6 +67,8 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
] ]
) )
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
IMAGE_INPAINTING_PARAMS = frozenset( IMAGE_INPAINTING_PARAMS = frozenset(
[ [
# image variation with an image mask # image variation with an image mask
@@ -68,6 +80,8 @@ IMAGE_INPAINTING_PARAMS = frozenset(
] ]
) )
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset( IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[ [
"example_image", "example_image",
@@ -79,12 +93,20 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
] ]
) )
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"]) IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"]) CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"]) CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
TEXT_TO_AUDIO_PARAMS = frozenset( TEXT_TO_AUDIO_PARAMS = frozenset(
[ [
"prompt", "prompt",
@@ -97,38 +119,11 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
] ]
) )
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
# image params
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
# batch params
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"]) TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"]) TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
# callback params
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"]) TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])