Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 532b395718 | |||
| 5c43924ac2 | |||
| a633289e10 |
@@ -193,28 +193,24 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Run torch hotswap + compile tests on GPU
|
||||
- name: Run torch compile tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "hotswap" --make-reports=tests_torch_hotswap_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
cat reports/tests_torch_hotswap_cuda_failures_short.txt
|
||||
run: cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_compile_hotswap_test_reports
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
|
||||
run_big_gpu_torch_tests:
|
||||
|
||||
@@ -189,7 +189,6 @@ jobs:
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
RUN_SLOW: yes
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
- name: Failure short reports
|
||||
|
||||
@@ -763,7 +763,4 @@ class LegacyConfigMixin(ConfigMixin):
|
||||
# resolve remapping
|
||||
remapped_class = _fetch_remapped_cls_from_config(config, cls)
|
||||
|
||||
if remapped_class is cls:
|
||||
return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
|
||||
else:
|
||||
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
|
||||
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
|
||||
|
||||
@@ -24,7 +24,7 @@ from typing_extensions import Self
|
||||
from .. import __version__
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
@@ -431,7 +431,10 @@ class FromOriginalModelMixin:
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
)
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
else:
|
||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from ..utils import (
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from ..utils.hub_utils import _get_model_file
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -1690,7 +1690,10 @@ def create_diffusers_clip_model_from_ldm(
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
|
||||
@@ -2150,7 +2153,10 @@ def create_diffusers_t5_model_from_checkpoint(
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from ..models.embeddings import (
|
||||
)
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import is_accelerate_available, is_torch_version, logging
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -82,6 +82,7 @@ class FluxTransformer2DLoadersMixin:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return image_projection
|
||||
|
||||
@@ -157,6 +158,7 @@ class FluxTransformer2DLoadersMixin:
|
||||
key_id += 1
|
||||
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return attn_procs
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
||||
from ..models.embeddings import IPAdapterTimeImageProjection
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import is_accelerate_available, is_torch_version, logging
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -82,6 +82,7 @@ class SD3Transformer2DLoadersMixin:
|
||||
)
|
||||
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return attn_procs
|
||||
|
||||
@@ -151,6 +152,7 @@ class SD3Transformer2DLoadersMixin:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return image_proj
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from .lora_base import _func_optionally_disable_offloading
|
||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
||||
from .utils import AttnProcsLayers
|
||||
@@ -755,6 +755,7 @@ class UNet2DConditionLoadersMixin:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return image_projection
|
||||
|
||||
@@ -853,6 +854,7 @@ class UNet2DConditionLoadersMixin:
|
||||
key_id += 2
|
||||
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return attn_procs
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ from ..utils.hub_utils import (
|
||||
load_or_create_model_card,
|
||||
populate_model_card,
|
||||
)
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from .model_loading_utils import (
|
||||
_caching_allocator_warmup,
|
||||
_determine_device_map,
|
||||
@@ -1540,7 +1540,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
|
||||
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
|
||||
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
if offload_index is not None and len(offload_index) > 0:
|
||||
save_offload_index(offload_index, offload_folder)
|
||||
@@ -1877,9 +1880,4 @@ class LegacyModelMixin(ModelMixin):
|
||||
# resolve remapping
|
||||
remapped_class = _fetch_remapped_cls_from_config(config, cls)
|
||||
|
||||
if remapped_class is cls:
|
||||
return super(LegacyModelMixin, remapped_class).from_pretrained(
|
||||
pretrained_model_name_or_path, **kwargs_copy
|
||||
)
|
||||
else:
|
||||
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
|
||||
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
|
||||
|
||||
@@ -323,6 +323,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
model_name = None
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
@@ -333,6 +334,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -358,7 +367,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
)
|
||||
if not (has_remote_code and trust_remote_code):
|
||||
raise ValueError("TODO")
|
||||
raise ValueError(
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
@@ -367,7 +378,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
pretrained_model_name_or_path,
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
is_modular=True,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -93,7 +93,7 @@ class ComponentSpec:
|
||||
config: Optional[FrozenDict] = None
|
||||
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
|
||||
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
|
||||
subfolder: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
subfolder: Optional[str] = field(default="", metadata={"loading": True})
|
||||
variant: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
revision: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
@@ -37,13 +38,7 @@ from ...loaders import (
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from ...models import (
|
||||
AutoencoderKL,
|
||||
ControlNetUnionModel,
|
||||
ImageProjection,
|
||||
MultiControlNetUnionModel,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
@@ -267,9 +262,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: Union[
|
||||
ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
|
||||
],
|
||||
controlnet: ControlNetUnionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
requires_aesthetics_score: bool = False,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
@@ -279,8 +272,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = MultiControlNetUnionModel(controlnet)
|
||||
if not isinstance(controlnet, ControlNetUnionModel):
|
||||
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
@@ -656,7 +649,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
control_mode=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
@@ -730,44 +722,28 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
# `prompt` needs more sophisticated handling when there are multiple
|
||||
# conditionings.
|
||||
if isinstance(self.controlnet, MultiControlNetUnionModel):
|
||||
if isinstance(prompt, list):
|
||||
logger.warning(
|
||||
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
|
||||
" prompts. The conditionings will be fixed across the prompts."
|
||||
)
|
||||
|
||||
# Check `image`
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
for image_ in image:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if not isinstance(image, list):
|
||||
raise TypeError("For multiple controlnets: `image` must be type `list`")
|
||||
elif not all(isinstance(i, list) for i in image):
|
||||
raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
|
||||
elif len(image) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
|
||||
)
|
||||
|
||||
for images_ in image:
|
||||
for image_ in images_:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
||||
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
|
||||
)
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
||||
):
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
elif (
|
||||
isinstance(self.controlnet, ControlNetUnionModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
|
||||
):
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
else:
|
||||
assert False
|
||||
|
||||
if not isinstance(control_guidance_start, (tuple, list)):
|
||||
control_guidance_start = [control_guidance_start]
|
||||
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if len(control_guidance_start) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
|
||||
)
|
||||
|
||||
if not isinstance(control_guidance_end, (tuple, list)):
|
||||
control_guidance_end = [control_guidance_end]
|
||||
|
||||
@@ -786,15 +762,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
if end > 1.0:
|
||||
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
||||
|
||||
# Check `control_mode`
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
if max(control_mode) >= controlnet.config.num_control_type:
|
||||
raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
|
||||
if max(_control_mode) >= _controlnet.config.num_control_type:
|
||||
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
||||
@@ -1082,7 +1049,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
|
||||
control_image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
strength: float = 0.8,
|
||||
@@ -1107,7 +1074,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
|
||||
control_mode: Optional[Union[int, List[int]]] = None,
|
||||
original_size: Tuple[int, int] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Tuple[int, int] = None,
|
||||
@@ -1137,13 +1104,13 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The initial image will be used as the starting point for the image generation process. Can also accept
|
||||
image latents as `image`, if passing latents directly, it will not be encoded again.
|
||||
control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
|
||||
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
||||
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
|
||||
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
|
||||
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
|
||||
images must be passed as a list such that each element of the list can be correctly batched for input
|
||||
to a single ControlNet.
|
||||
control_image (`PipelineImageInput`):
|
||||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
||||
the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
|
||||
be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
|
||||
and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
|
||||
init, images must be passed as a list such that each element of the list can be correctly batched for
|
||||
input to a single controlnet.
|
||||
height (`int`, *optional*, defaults to the size of control_image):
|
||||
The height in pixels of the generated image. Anything below 512 pixels won't work well for
|
||||
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
@@ -1217,21 +1184,16 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
||||
the corresponding scale as a list.
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
|
||||
corresponding scale as a list.
|
||||
guess_mode (`bool`, *optional*, defaults to `False`):
|
||||
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
|
||||
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
||||
The percentage of total steps at which the ControlNet starts applying.
|
||||
The percentage of total steps at which the controlnet starts applying.
|
||||
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The percentage of total steps at which the ControlNet stops applying.
|
||||
control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
|
||||
The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
|
||||
available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
|
||||
where each ControlNet should have its corresponding control mode list. Should reflect the order of
|
||||
conditions in control_image
|
||||
The percentage of total steps at which the controlnet stops applying.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
||||
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
||||
@@ -1311,6 +1273,12 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
|
||||
if not isinstance(control_image, list):
|
||||
control_image = [control_image]
|
||||
else:
|
||||
@@ -1319,56 +1287,37 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode]
|
||||
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_image = [[item] for item in control_image]
|
||||
control_mode = [[item] for item in control_mode]
|
||||
if len(control_image) != len(control_mode):
|
||||
raise ValueError("Expected len(control_image) == len(control_type)")
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
if isinstance(controlnet_conditioning_scale, float):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
|
||||
num_control_type = controlnet.config.num_control_type
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
control_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
control_mode,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
control_type = [0 for _ in range(num_control_type)]
|
||||
for _image, control_idx in zip(control_image, control_mode):
|
||||
control_type[control_idx] = 1
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_type = [
|
||||
torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
|
||||
for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
|
||||
]
|
||||
control_type = torch.Tensor(control_type)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
@@ -1385,11 +1334,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
global_pool_conditions = (
|
||||
controlnet.config.global_pool_conditions
|
||||
if isinstance(controlnet, ControlNetUnionModel)
|
||||
else controlnet.nets[0].config.global_pool_conditions
|
||||
)
|
||||
global_pool_conditions = controlnet.config.global_pool_conditions
|
||||
guess_mode = guess_mode or global_pool_conditions
|
||||
|
||||
# 3.1. Encode input prompt
|
||||
@@ -1427,55 +1372,22 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 4.1 Prepare image
|
||||
# 4. Prepare image and controlnet_conditioning_image
|
||||
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
|
||||
# 4.2 Prepare control images
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
control_images = []
|
||||
|
||||
for image_ in control_image:
|
||||
image_ = self.prepare_control_image(
|
||||
image=image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
control_images.append(image_)
|
||||
|
||||
control_image = control_images
|
||||
height, width = control_image[0].shape[-2:]
|
||||
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_images = []
|
||||
|
||||
for control_image_ in control_image:
|
||||
images = []
|
||||
|
||||
for image_ in control_image_:
|
||||
image_ = self.prepare_control_image(
|
||||
image=image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
images.append(image_)
|
||||
control_images.append(images)
|
||||
|
||||
control_image = control_images
|
||||
height, width = control_image[0][0].shape[-2:]
|
||||
for idx, _ in enumerate(control_image):
|
||||
control_image[idx] = self.prepare_control_image(
|
||||
image=control_image[idx],
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
height, width = control_image[idx].shape[-2:]
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -1502,11 +1414,10 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
# 7.1 Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps)
|
||||
controlnet_keep.append(
|
||||
1.0
|
||||
- float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
|
||||
)
|
||||
|
||||
# 7.2 Prepare added time ids & embeddings
|
||||
original_size = original_size or (height, width)
|
||||
@@ -1549,25 +1460,12 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
control_type_repeat_factor = (
|
||||
batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
|
||||
control_type = (
|
||||
control_type.reshape(1, -1)
|
||||
.to(device, dtype=prompt_embeds.dtype)
|
||||
.repeat(batch_size * num_images_per_prompt * 2, 1)
|
||||
)
|
||||
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
control_type = (
|
||||
control_type.reshape(1, -1)
|
||||
.to(self._execution_device, dtype=prompt_embeds.dtype)
|
||||
.repeat(control_type_repeat_factor, 1)
|
||||
)
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_type = [
|
||||
_control_type.reshape(1, -1)
|
||||
.to(self._execution_device, dtype=prompt_embeds.dtype)
|
||||
.repeat(control_type_repeat_factor, 1)
|
||||
for _control_type in control_type
|
||||
]
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
|
||||
@@ -840,8 +840,6 @@ class FluxPipeline(
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
|
||||
@@ -383,8 +383,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# Scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = np.array([noise_level]).astype(np.int64)
|
||||
|
||||
@@ -153,8 +153,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
flow_shift: Optional[float] = 1.0,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
use_dynamic_shifting: bool = False,
|
||||
time_shift_type: str = "exponential",
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@@ -234,9 +232,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(
|
||||
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
|
||||
):
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -246,9 +242,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
if mu is not None:
|
||||
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
|
||||
self.config.flow_shift = np.exp(mu)
|
||||
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
|
||||
if self.config.timestep_spacing == "linspace":
|
||||
timesteps = (
|
||||
|
||||
@@ -230,8 +230,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
use_dynamic_shifting: bool = False,
|
||||
time_shift_type: str = "exponential",
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@@ -332,7 +330,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
mu: Optional[float] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
@@ -348,9 +345,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
|
||||
must be `None`, and `timestep_spacing` attribute will be ignored.
|
||||
"""
|
||||
if mu is not None:
|
||||
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
|
||||
self.config.flow_shift = np.exp(mu)
|
||||
if num_inference_steps is None and timesteps is None:
|
||||
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
|
||||
if num_inference_steps is not None and timesteps is not None:
|
||||
|
||||
@@ -169,8 +169,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
use_dynamic_shifting: bool = False,
|
||||
time_shift_type: str = "exponential",
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@@ -303,7 +301,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
mu: Optional[float] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
@@ -319,9 +316,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
|
||||
passed, `num_inference_steps` must be `None`.
|
||||
"""
|
||||
if mu is not None:
|
||||
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
|
||||
self.config.flow_shift = np.exp(mu)
|
||||
if num_inference_steps is None and timesteps is None:
|
||||
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
|
||||
if num_inference_steps is not None and timesteps is not None:
|
||||
|
||||
@@ -212,8 +212,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
steps_offset: int = 0,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
use_dynamic_shifting: bool = False,
|
||||
time_shift_type: str = "exponential",
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@@ -300,9 +298,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(
|
||||
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
|
||||
):
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -313,9 +309,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
|
||||
if mu is not None:
|
||||
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
|
||||
self.config.flow_shift = np.exp(mu)
|
||||
if self.config.timestep_spacing == "linspace":
|
||||
timesteps = (
|
||||
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
|
||||
|
||||
Reference in New Issue
Block a user