Compare commits

...

1 Commits

Author SHA1 Message Date
DN6 c8a7617536 update 2025-05-12 19:37:28 +05:30
21 changed files with 927 additions and 754 deletions
+2 -2
View File
@@ -761,8 +761,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LayerSkipConfig, LayerSkipConfig,
PyramidAttentionBroadcastConfig, PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig, SmoothedEnergyGuidanceConfig,
apply_layer_skip,
apply_faster_cache, apply_faster_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast, apply_pyramid_attention_broadcast,
) )
from .models import ( from .models import (
@@ -1085,6 +1085,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionSAGPipeline, StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline, StableDiffusionUpscalePipeline,
StableDiffusionXLAdapterPipeline, StableDiffusionXLAdapterPipeline,
StableDiffusionXLAutoPipeline,
StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline, StableDiffusionXLControlNetPAGImg2ImgPipeline,
@@ -1102,7 +1103,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline, StableDiffusionXLPAGPipeline,
StableDiffusionXLPipeline, StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
StableUnCLIPImg2ImgPipeline, StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline, StableUnCLIPPipeline,
StableVideoDiffusionPipeline, StableVideoDiffusionPipeline,
@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, List, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional
import torch import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -119,19 +120,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
def _is_apg_enabled(self) -> bool: def _is_apg_enabled(self) -> bool:
if not self._enabled: if not self._enabled:
return False return False
is_within_range = True is_within_range = True
if self._num_inference_steps is not None: if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps) skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False is_close = False
if self.use_original_formulation: if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0) is_close = math.isclose(self.guidance_scale, 0.0)
else: else:
is_close = math.isclose(self.guidance_scale, 1.0) is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close return is_within_range and not is_close
@@ -156,25 +157,25 @@ def normalized_guidance(
): ):
diff = pred_cond - pred_uncond diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))] dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None: if momentum_buffer is not None:
momentum_buffer.update(diff) momentum_buffer.update(diff)
diff = momentum_buffer.running_average diff = momentum_buffer.running_average
if norm_threshold > 0: if norm_threshold > 0:
ones = torch.ones_like(diff) ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True) diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm) scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double() v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim) v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update pred = pred + guidance_scale * normalized_update
return pred return pred
+9 -8
View File
@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import List, Optional, Union, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -113,13 +114,13 @@ class AutoGuidance(BaseGuidance):
if self._is_ag_enabled() and self.is_unconditional: if self._is_ag_enabled() and self.is_unconditional:
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config): for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
_apply_layer_skip_hook(denoiser, config, name=name) _apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None: def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_ag_enabled() and self.is_unconditional: if self._is_ag_enabled() and self.is_unconditional:
for name in self._auto_guidance_hook_names: for name in self._auto_guidance_hook_names:
registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True) registry.remove_hook(name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1] tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = [] data_batches = []
@@ -140,9 +141,9 @@ class AutoGuidance(BaseGuidance):
if self.guidance_rescale > 0.0: if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {} return pred, {}
@property @property
def is_conditional(self) -> bool: def is_conditional(self) -> bool:
return self._count_prepared == 1 return self._count_prepared == 1
@@ -157,17 +158,17 @@ class AutoGuidance(BaseGuidance):
def _is_ag_enabled(self) -> bool: def _is_ag_enabled(self) -> bool:
if not self._enabled: if not self._enabled:
return False return False
is_within_range = True is_within_range = True
if self._num_inference_steps is not None: if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps) skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False is_close = False
if self.use_original_formulation: if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0) is_close = math.isclose(self.guidance_scale, 0.0)
else: else:
is_close = math.isclose(self.guidance_scale, 1.0) is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close return is_within_range and not is_close
@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, List, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional
import torch import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -74,7 +75,7 @@ class ClassifierFreeGuidance(BaseGuidance):
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1] tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = [] data_batches = []
@@ -112,17 +113,17 @@ class ClassifierFreeGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool: def _is_cfg_enabled(self) -> bool:
if not self._enabled: if not self._enabled:
return False return False
is_within_range = True is_within_range = True
if self._num_inference_steps is not None: if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps) skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False is_close = False
if self.use_original_formulation: if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0) is_close = math.isclose(self.guidance_scale, 0.0)
else: else:
is_close = math.isclose(self.guidance_scale, 1.0) is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close return is_within_range and not is_close
@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, List, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional
import torch import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -72,7 +73,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
self.zero_init_steps = zero_init_steps self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1] tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = [] data_batches = []
@@ -102,7 +103,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {} return pred, {}
@property @property
def is_conditional(self) -> bool: def is_conditional(self) -> bool:
return self._count_prepared == 1 return self._count_prepared == 1
@@ -117,19 +118,19 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool: def _is_cfg_enabled(self) -> bool:
if not self._enabled: if not self._enabled:
return False return False
is_within_range = True is_within_range = True
if self._num_inference_steps is not None: if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps) skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False is_close = False
if self.use_original_formulation: if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0) is_close = math.isclose(self.guidance_scale, 0.0)
else: else:
is_close = math.isclose(self.guidance_scale, 1.0) is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close return is_within_range and not is_close
+8 -8
View File
@@ -58,10 +58,10 @@ class BaseGuidance:
def disable(self): def disable(self):
self._enabled = False self._enabled = False
def enable(self): def enable(self):
self._enabled = True self._enabled = True
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step self._step = step
self._num_inference_steps = num_inference_steps self._num_inference_steps = num_inference_steps
@@ -104,14 +104,14 @@ class BaseGuidance:
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}." f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
) )
self._input_fields = kwargs self._input_fields = kwargs
def prepare_models(self, denoiser: torch.nn.Module) -> None: def prepare_models(self, denoiser: torch.nn.Module) -> None:
""" """
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
subclasses to implement specific model preparation logic. subclasses to implement specific model preparation logic.
""" """
self._count_prepared += 1 self._count_prepared += 1
def cleanup_models(self, denoiser: torch.nn.Module) -> None: def cleanup_models(self, denoiser: torch.nn.Module) -> None:
""" """
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in
@@ -119,7 +119,7 @@ class BaseGuidance:
modifications made during `prepare_models`. modifications made during `prepare_models`.
""" """
pass pass
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
@@ -139,15 +139,15 @@ class BaseGuidance:
@property @property
def is_conditional(self) -> bool: def is_conditional(self) -> bool:
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
@property @property
def is_unconditional(self) -> bool: def is_unconditional(self) -> bool:
return not self.is_conditional return not self.is_conditional
@property @property
def num_conditions(self) -> int: def num_conditions(self) -> int:
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
@classmethod @classmethod
def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState": def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
""" """
+11 -10
View File
@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import List, Optional, Union, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -148,14 +149,14 @@ class SkipLayerGuidance(BaseGuidance):
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name) _apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None: def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference # Remove the hooks after inference
for hook_name in self._skip_layer_hook_names: for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True) registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
if self.num_conditions == 1: if self.num_conditions == 1:
tuple_indices = [0] tuple_indices = [0]
@@ -200,7 +201,7 @@ class SkipLayerGuidance(BaseGuidance):
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {} return pred, {}
@property @property
def is_conditional(self) -> bool: def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3 return self._count_prepared == 1 or self._count_prepared == 3
@@ -217,31 +218,31 @@ class SkipLayerGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool: def _is_cfg_enabled(self) -> bool:
if not self._enabled: if not self._enabled:
return False return False
is_within_range = True is_within_range = True
if self._num_inference_steps is not None: if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps) skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False is_close = False
if self.use_original_formulation: if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0) is_close = math.isclose(self.guidance_scale, 0.0)
else: else:
is_close = math.isclose(self.guidance_scale, 1.0) is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close return is_within_range and not is_close
def _is_slg_enabled(self) -> bool: def _is_slg_enabled(self) -> bool:
if not self._enabled: if not self._enabled:
return False return False
is_within_range = True is_within_range = True
if self._num_inference_steps is not None: if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero return is_within_range and not is_zero
@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import List, Optional, Union, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -141,14 +142,14 @@ class SmoothedEnergyGuidance(BaseGuidance):
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
def cleanup_models(self, denoiser: torch.nn.Module): def cleanup_models(self, denoiser: torch.nn.Module):
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference # Remove the hooks after inference
for hook_name in self._seg_layer_hook_names: for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True) registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
if self.num_conditions == 1: if self.num_conditions == 1:
tuple_indices = [0] tuple_indices = [0]
@@ -193,7 +194,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {} return pred, {}
@property @property
def is_conditional(self) -> bool: def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3 return self._count_prepared == 1 or self._count_prepared == 3
@@ -210,31 +211,31 @@ class SmoothedEnergyGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool: def _is_cfg_enabled(self) -> bool:
if not self._enabled: if not self._enabled:
return False return False
is_within_range = True is_within_range = True
if self._num_inference_steps is not None: if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps) skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False is_close = False
if self.use_original_formulation: if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0) is_close = math.isclose(self.guidance_scale, 0.0)
else: else:
is_close = math.isclose(self.guidance_scale, 1.0) is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close return is_within_range and not is_close
def _is_seg_enabled(self) -> bool: def _is_seg_enabled(self) -> bool:
if not self._enabled: if not self._enabled:
return False return False
is_within_range = True is_within_range = True
if self._num_inference_steps is not None: if self._num_inference_steps is not None:
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps) skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps) skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.seg_guidance_scale, 0.0) is_zero = math.isclose(self.seg_guidance_scale, 0.0)
return is_within_range and not is_zero return is_within_range and not is_zero
@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, List, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional
import torch import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -97,24 +98,24 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
def _is_tcfg_enabled(self) -> bool: def _is_tcfg_enabled(self) -> bool:
if not self._enabled: if not self._enabled:
return False return False
is_within_range = True is_within_range = True
if self._num_inference_steps is not None: if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps) skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False is_close = False
if self.use_original_formulation: if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0) is_close = math.isclose(self.guidance_scale, 0.0)
else: else:
is_close = math.isclose(self.guidance_scale, 1.0) is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close return is_within_range and not is_close
def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor: def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor:
cond_dtype = pred_cond.dtype cond_dtype = pred_cond.dtype
preds = torch.stack([pred_cond, pred_uncond], dim=1).float() preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
preds = preds.flatten(2) preds = preds.flatten(2)
U, S, Vh = torch.linalg.svd(preds, full_matrices=False) U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
@@ -125,9 +126,9 @@ def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guid
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1)) x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
x_Vh_V = torch.matmul(x_Vh, Vh_modified) x_Vh_V = torch.matmul(x_Vh, Vh_modified)
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype) pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
pred = pred_cond if use_original_formulation else pred_uncond pred = pred_cond if use_original_formulation else pred_uncond
shift = pred_cond - pred_uncond shift = pred_cond - pred_uncond
pred = pred + guidance_scale * shift pred = pred + guidance_scale * shift
return pred return pred
+10 -5
View File
@@ -20,7 +20,12 @@ import torch
from ..utils import get_logger from ..utils import get_logger
from ..utils.torch_utils import unwrap_module from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn from ._common import (
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
_ATTENTION_CLASSES,
_FEEDFORWARD_CLASSES,
_get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook from .hooks import HookRegistry, ModelHook
@@ -196,15 +201,15 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
for i, block in enumerate(transformer_blocks): for i, block in enumerate(transformer_blocks):
if i not in config.indices: if i not in config.indices:
continue continue
blocks_found = True blocks_found = True
if config.skip_attention and config.skip_ff: if config.skip_attention and config.skip_ff:
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
registry = HookRegistry.check_if_exists_or_initialize(block) registry = HookRegistry.check_if_exists_or_initialize(block)
hook = TransformerBlockSkipHook(config.dropout) hook = TransformerBlockSkipHook(config.dropout)
registry.register_hook(hook, name) registry.register_hook(hook, name)
elif config.skip_attention or config.skip_attention_scores: elif config.skip_attention or config.skip_attention_scores:
for submodule_name, submodule in block.named_modules(): for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
@@ -213,7 +218,7 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
registry = HookRegistry.check_if_exists_or_initialize(submodule) registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout) hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
registry.register_hook(hook, name) registry.register_hook(hook, name)
if config.skip_ff: if config.skip_ff:
for submodule_name, submodule in block.named_modules(): for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES): if isinstance(submodule, _FEEDFORWARD_CLASSES):
@@ -14,7 +14,7 @@
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple from typing import List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -67,7 +67,7 @@ class SmoothedEnergyGuidanceHook(ModelHook):
def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None: def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
if config.fqn == "auto": if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier): if hasattr(module, identifier):
@@ -78,18 +78,18 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks." "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
) )
if config._query_proj_identifiers is None: if config._query_proj_identifiers is None:
config._query_proj_identifiers = ["to_q"] config._query_proj_identifiers = ["to_q"]
transformer_blocks = _get_submodule_from_fqn(module, config.fqn) transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
blocks_found = False blocks_found = False
for i, block in enumerate(transformer_blocks): for i, block in enumerate(transformer_blocks):
if i not in config.indices: if i not in config.indices:
continue continue
blocks_found = True blocks_found = True
for submodule_name, submodule in block.named_modules(): for submodule_name, submodule in block.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention: if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
continue continue
@@ -103,7 +103,7 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
registry = HookRegistry.check_if_exists_or_initialize(query_proj) registry = HookRegistry.check_if_exists_or_initialize(query_proj)
hook = SmoothedEnergyGuidanceHook(blur_sigma) hook = SmoothedEnergyGuidanceHook(blur_sigma)
registry.register_hook(hook, name) registry.register_hook(hook, name)
if not blocks_found: if not blocks_found:
raise ValueError( raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and " f"Could not find any transformer blocks matching the provided indices {config.indices} and "
@@ -124,7 +124,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
in the future without warning or guarantee of reproducibility. in the future without warning or guarantee of reproducibility.
""" """
assert query.ndim == 3 assert query.ndim == 3
is_inf = sigma > sigma_threshold_inf is_inf = sigma > sigma_threshold_inf
batch_size, seq_len, embed_dim = query.shape batch_size, seq_len, embed_dim = query.shape
@@ -133,7 +133,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
query_slice = query[:, :num_square_tokens, :] query_slice = query[:, :num_square_tokens, :]
query_slice = query_slice.permute(0, 2, 1) query_slice = query_slice.permute(0, 2, 1)
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt) query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
if is_inf: if is_inf:
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1)) kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
kernel_size_half = (kernel_size - 1) / 2 kernel_size_half = (kernel_size - 1) / 2
@@ -154,5 +154,5 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
query_slice = query_slice.permute(0, 2, 1) query_slice = query_slice.permute(0, 2, 1)
query[:, :num_square_tokens, :] = query_slice.clone() query[:, :num_square_tokens, :] = query_slice.clone()
return query return query
+1 -1
View File
@@ -102,8 +102,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .ip_adapter import ( from .ip_adapter import (
FluxIPAdapterMixin, FluxIPAdapterMixin,
IPAdapterMixin, IPAdapterMixin,
SD3IPAdapterMixin,
ModularIPAdapterMixin, ModularIPAdapterMixin,
SD3IPAdapterMixin,
) )
from .lora_pipeline import ( from .lora_pipeline import (
AmusedLoraLoaderMixin, AmusedLoraLoaderMixin,
+1 -1
View File
@@ -703,12 +703,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .stable_diffusion_sag import StableDiffusionSAGPipeline from .stable_diffusion_sag import StableDiffusionSAGPipeline
from .stable_diffusion_xl import ( from .stable_diffusion_xl import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline, StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline, StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLModularLoader, StableDiffusionXLModularLoader,
StableDiffusionXLPipeline, StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
) )
from .stable_video_diffusion import StableVideoDiffusionPipeline from .stable_video_diffusion import StableVideoDiffusionPipeline
from .t2i_adapter import ( from .t2i_adapter import (
+76 -78
View File
@@ -12,21 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import time
from collections import OrderedDict from collections import OrderedDict
from itertools import combinations from itertools import combinations
from typing import List, Optional, Union, Dict, Any from typing import Any, Dict, List, Optional, Union
import copy
import torch import torch
import time
from dataclasses import dataclass
from ..utils import ( from ..utils import (
is_accelerate_available, is_accelerate_available,
logging, logging,
) )
from ..models.modeling_utils import ModelMixin
from .modular_pipeline_utils import ComponentSpec
if is_accelerate_available(): if is_accelerate_available():
@@ -231,17 +228,18 @@ class AutoOffloadStrategy:
from .modular_pipeline_utils import ComponentSpec
import uuid import uuid
class ComponentsManager: class ComponentsManager:
def __init__(self): def __init__(self):
self.components = OrderedDict() self.components = OrderedDict()
self.added_time = OrderedDict() # Store when components were added self.added_time = OrderedDict() # Store when components were added
self.collections = OrderedDict() # collection_name -> set of component_names self.collections = OrderedDict() # collection_name -> set of component_names
self.model_hooks = None self.model_hooks = None
self._auto_offload_enabled = False self._auto_offload_enabled = False
def _get_by_collection(self, collection: str): def _get_by_collection(self, collection: str):
""" """
Select components by collection name. Select components by collection name.
@@ -252,8 +250,8 @@ class ComponentsManager:
for component_id in component_ids: for component_id in component_ids:
selected_components[component_id] = self.components[component_id] selected_components[component_id] = self.components[component_id]
return selected_components return selected_components
def _get_by_load_id(self, load_id: str): def _get_by_load_id(self, load_id: str):
""" """
Select components by its load_id. Select components by its load_id.
@@ -263,8 +261,8 @@ class ComponentsManager:
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
selected_components[name] = component selected_components[name] = component
return selected_components return selected_components
def add(self, name, component, collection: Optional[str] = None): def add(self, name, component, collection: Optional[str] = None):
for comp_id, comp in self.components.items(): for comp_id, comp in self.components.items():
@@ -282,7 +280,7 @@ class ComponentsManager:
f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
f"To remove a duplicate, call `components_manager.remove('<component_name>')`." f"To remove a duplicate, call `components_manager.remove('<component_name>')`."
) )
# add component to components manager # add component to components manager
self.components[component_id] = component self.components[component_id] = component
@@ -293,8 +291,8 @@ class ComponentsManager:
self.collections[collection].add(component_id) self.collections[collection].add(component_id)
if self._auto_offload_enabled: if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device) self.enable_auto_cpu_offload(self._auto_offload_device)
logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'")
return component_id return component_id
@@ -304,14 +302,14 @@ class ComponentsManager:
if name not in self.components: if name not in self.components:
logger.warning(f"Component '{name}' not found in ComponentsManager") logger.warning(f"Component '{name}' not found in ComponentsManager")
return return
self.components.pop(name) self.components.pop(name)
self.added_time.pop(name) self.added_time.pop(name)
for collection in self.collections: for collection in self.collections:
if name in self.collections[collection]: if name in self.collections[collection]:
self.collections[collection].remove(name) self.collections[collection].remove(name)
if self._auto_offload_enabled: if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device) self.enable_auto_cpu_offload(self._auto_offload_device)
@@ -341,7 +339,7 @@ class ComponentsManager:
Dictionary mapping component IDs to components, Dictionary mapping component IDs to components,
or list of (base_name, component) tuples if as_name_component_tuples=True or list of (base_name, component) tuples if as_name_component_tuples=True
""" """
if collection: if collection:
if collection not in self.collections: if collection not in self.collections:
logger.warning(f"Collection '{collection}' not found in ComponentsManager") logger.warning(f"Collection '{collection}' not found in ComponentsManager")
@@ -360,16 +358,16 @@ class ComponentsManager:
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
return '_'.join(parts[:-1]) return '_'.join(parts[:-1])
return component_id return component_id
if names is None: if names is None:
if as_name_component_tuples: if as_name_component_tuples:
return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
else: else:
return components return components
# Create mapping from component_id to base_name for all components # Create mapping from component_id to base_name for all components
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
def matches_pattern(component_id, pattern, exact_match=False): def matches_pattern(component_id, pattern, exact_match=False):
""" """
Helper function to check if a component matches a pattern based on its base name. Helper function to check if a component matches a pattern based on its base name.
@@ -380,124 +378,124 @@ class ComponentsManager:
exact_match: If True, only exact matches to base_name are considered exact_match: If True, only exact matches to base_name are considered
""" """
base_name = base_names[component_id] base_name = base_names[component_id]
# Exact match with base name # Exact match with base name
if exact_match: if exact_match:
return pattern == base_name return pattern == base_name
# Prefix match (ends with *) # Prefix match (ends with *)
elif pattern.endswith('*'): elif pattern.endswith('*'):
prefix = pattern[:-1] prefix = pattern[:-1]
return base_name.startswith(prefix) return base_name.startswith(prefix)
# Contains match (starts with *) # Contains match (starts with *)
elif pattern.startswith('*'): elif pattern.startswith('*'):
search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
return search in base_name return search in base_name
# Exact match (no wildcards) # Exact match (no wildcards)
else: else:
return pattern == base_name return pattern == base_name
if isinstance(names, str): if isinstance(names, str):
# Check if this is a "not" pattern # Check if this is a "not" pattern
is_not_pattern = names.startswith('!') is_not_pattern = names.startswith('!')
if is_not_pattern: if is_not_pattern:
names = names[1:] # Remove the ! prefix names = names[1:] # Remove the ! prefix
# Handle OR patterns (containing |) # Handle OR patterns (containing |)
if '|' in names: if '|' in names:
terms = names.split('|') terms = names.split('|')
matches = {} matches = {}
for comp_id, comp in components.items(): for comp_id, comp in components.items():
# For OR patterns with exact names (no wildcards), we do exact matching on base names # For OR patterns with exact names (no wildcards), we do exact matching on base names
exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)
# Check if any of the terms match this component # Check if any of the terms match this component
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
# Flip the decision if this is a NOT pattern # Flip the decision if this is a NOT pattern
if is_not_pattern: if is_not_pattern:
should_include = not should_include should_include = not should_include
if should_include: if should_include:
matches[comp_id] = comp matches[comp_id] = comp
log_msg = "NOT " if is_not_pattern else "" log_msg = "NOT " if is_not_pattern else ""
match_type = "exactly matching" if exact_match else "matching any of patterns" match_type = "exactly matching" if exact_match else "matching any of patterns"
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
# Try exact match with a base name # Try exact match with a base name
elif any(names == base_name for base_name in base_names.values()): elif any(names == base_name for base_name in base_names.values()):
# Find all components with this base name # Find all components with this base name
matches = { matches = {
comp_id: comp for comp_id, comp in components.items() comp_id: comp for comp_id, comp in components.items()
if (base_names[comp_id] == names) != is_not_pattern if (base_names[comp_id] == names) != is_not_pattern
} }
if is_not_pattern: if is_not_pattern:
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
else: else:
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
# Prefix match (ends with *) # Prefix match (ends with *)
elif names.endswith('*'): elif names.endswith('*'):
prefix = names[:-1] prefix = names[:-1]
matches = { matches = {
comp_id: comp for comp_id, comp in components.items() comp_id: comp for comp_id, comp in components.items()
if base_names[comp_id].startswith(prefix) != is_not_pattern if base_names[comp_id].startswith(prefix) != is_not_pattern
} }
if is_not_pattern: if is_not_pattern:
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
else: else:
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
# Contains match (starts with *) # Contains match (starts with *)
elif names.startswith('*'): elif names.startswith('*'):
search = names[1:-1] if names.endswith('*') else names[1:] search = names[1:-1] if names.endswith('*') else names[1:]
matches = { matches = {
comp_id: comp for comp_id, comp in components.items() comp_id: comp for comp_id, comp in components.items()
if (search in base_names[comp_id]) != is_not_pattern if (search in base_names[comp_id]) != is_not_pattern
} }
if is_not_pattern: if is_not_pattern:
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
else: else:
logger.info(f"Getting components containing '{search}': {list(matches.keys())}") logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
# Substring match (no wildcards, but not an exact component name) # Substring match (no wildcards, but not an exact component name)
elif any(names in base_name for base_name in base_names.values()): elif any(names in base_name for base_name in base_names.values()):
matches = { matches = {
comp_id: comp for comp_id, comp in components.items() comp_id: comp for comp_id, comp in components.items()
if (names in base_names[comp_id]) != is_not_pattern if (names in base_names[comp_id]) != is_not_pattern
} }
if is_not_pattern: if is_not_pattern:
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
else: else:
logger.info(f"Getting components containing '{names}': {list(matches.keys())}") logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
else: else:
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
if not matches: if not matches:
raise ValueError(f"No components found matching pattern '{names}'") raise ValueError(f"No components found matching pattern '{names}'")
if as_name_component_tuples: if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
else: else:
return matches return matches
elif isinstance(names, list): elif isinstance(names, list):
results = {} results = {}
for name in names: for name in names:
result = self.get(name, collection, load_id, as_name_component_tuples=False) result = self.get(name, collection, load_id, as_name_component_tuples=False)
results.update(result) results.update(result)
if as_name_component_tuples: if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in results.items()] return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
else: else:
return results return results
else: else:
raise ValueError(f"Invalid type for names: {type(names)}") raise ValueError(f"Invalid type for names: {type(names)}")
@@ -558,14 +556,14 @@ class ComponentsManager:
raise ValueError(f"Component '{name}' not found in ComponentsManager") raise ValueError(f"Component '{name}' not found in ComponentsManager")
component = self.components[name] component = self.components[name]
# Build complete info dict first # Build complete info dict first
info = { info = {
"model_id": name, "model_id": name,
"added_time": self.added_time[name], "added_time": self.added_time[name],
"collection": next((coll for coll, comps in self.collections.items() if name in comps), None), "collection": next((coll for coll, comps in self.collections.items() if name in comps), None),
} }
# Additional info for torch.nn.Module components # Additional info for torch.nn.Module components
if isinstance(component, torch.nn.Module): if isinstance(component, torch.nn.Module):
# Check for hook information # Check for hook information
@@ -573,7 +571,7 @@ class ComponentsManager:
execution_device = None execution_device = None
if has_hook and hasattr(component._hf_hook, "execution_device"): if has_hook and hasattr(component._hf_hook, "execution_device"):
execution_device = component._hf_hook.execution_device execution_device = component._hf_hook.execution_device
info.update({ info.update({
"class_name": component.__class__.__name__, "class_name": component.__class__.__name__,
"size_gb": get_memory_footprint(component) / (1024**3), "size_gb": get_memory_footprint(component) / (1024**3),
@@ -594,8 +592,8 @@ class ComponentsManager:
if any("IPAdapter" in ptype for ptype in processor_types): if any("IPAdapter" in ptype for ptype in processor_types):
# Then get scales only from IP-Adapter processors # Then get scales only from IP-Adapter processors
scales = { scales = {
k: v.scale k: v.scale
for k, v in processors.items() for k, v in processors.items()
if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
} }
if scales: if scales:
@@ -609,7 +607,7 @@ class ComponentsManager:
else: else:
# List of fields requested, return dict with just those fields # List of fields requested, return dict with just those fields
return {k: v for k, v in info.items() if k in fields} return {k: v for k, v in info.items() if k in fields}
return info return info
def __repr__(self): def __repr__(self):
@@ -622,13 +620,13 @@ class ComponentsManager:
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
return '_'.join(parts[:-1]) return '_'.join(parts[:-1])
return name return name
# Extract load_id if available # Extract load_id if available
def get_load_id(component): def get_load_id(component):
if hasattr(component, "_diffusers_load_id"): if hasattr(component, "_diffusers_load_id"):
return component._diffusers_load_id return component._diffusers_load_id
return "N/A" return "N/A"
# Format device info compactly # Format device info compactly
def format_device(component, info): def format_device(component, info):
if not info["has_hook"]: if not info["has_hook"]:
@@ -637,24 +635,24 @@ class ComponentsManager:
device = str(getattr(component, 'device', 'N/A')) device = str(getattr(component, 'device', 'N/A'))
exec_device = str(info['execution_device'] or 'N/A') exec_device = str(info['execution_device'] or 'N/A')
return f"{device}({exec_device})" return f"{device}({exec_device})"
# Get all simple names to calculate width # Get all simple names to calculate width
simple_names = [get_simple_name(id) for id in self.components.keys()] simple_names = [get_simple_name(id) for id in self.components.keys()]
# Get max length of load_ids for models # Get max length of load_ids for models
load_ids = [ load_ids = [
get_load_id(component) get_load_id(component)
for component in self.components.values() for component in self.components.values()
if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
] ]
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
# Collection names # Collection names
collection_names = [ collection_names = [
next((coll for coll, comps in self.collections.items() if name in comps), "N/A") next((coll for coll, comps in self.collections.items() if name in comps), "N/A")
for name in self.components.keys() for name in self.components.keys()
] ]
col_widths = { col_widths = {
"name": max(15, max(len(name) for name in simple_names)), "name": max(15, max(len(name) for name in simple_names)),
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
@@ -692,7 +690,7 @@ class ComponentsManager:
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
load_id = get_load_id(component) load_id = get_load_id(component)
collection = info["collection"] or "N/A" collection = info["collection"] or "N/A"
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n"
@@ -712,7 +710,7 @@ class ComponentsManager:
info = self.get_model_info(name) info = self.get_model_info(name)
simple_name = get_simple_name(name) simple_name = get_simple_name(name)
collection = info["collection"] or "N/A" collection = info["collection"] or "N/A"
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n"
output += dash_line output += dash_line
@@ -726,9 +724,9 @@ class ComponentsManager:
if info.get("adapters") is not None: if info.get("adapters") is not None:
output += f" Adapters: {info['adapters']}\n" output += f" Adapters: {info['adapters']}\n"
if info.get("ip_adapter"): if info.get("ip_adapter"):
output += f" IP-Adapter: Enabled\n" output += " IP-Adapter: Enabled\n"
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
return output return output
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
@@ -759,13 +757,13 @@ class ComponentsManager:
from ..pipelines.pipeline_utils import DiffusionPipeline from ..pipelines.pipeline_utils import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
for name, component in pipe.components.items(): for name, component in pipe.components.items():
if component is None: if component is None:
continue continue
# Add prefix if specified # Add prefix if specified
component_name = f"{prefix}_{name}" if prefix else name component_name = f"{prefix}_{name}" if prefix else name
if component_name not in self.components: if component_name not in self.components:
self.add(component_name, component) self.add(component_name, component)
else: else:
@@ -791,13 +789,13 @@ class ComponentsManager:
ValueError: If no components match or multiple components match ValueError: If no components match or multiple components match
""" """
results = self.get(name, collection, load_id) results = self.get(name, collection, load_id)
if not results: if not results:
raise ValueError(f"No components found matching '{name}'") raise ValueError(f"No components found matching '{name}'")
if len(results) > 1: if len(results) > 1:
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
return next(iter(results.values())) return next(iter(results.values()))
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
@@ -823,17 +821,17 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
if value_tuple not in value_to_keys: if value_tuple not in value_to_keys:
value_to_keys[value_tuple] = [] value_to_keys[value_tuple] = []
value_to_keys[value_tuple].append(key) value_to_keys[value_tuple].append(key)
def find_common_prefix(keys: List[str]) -> str: def find_common_prefix(keys: List[str]) -> str:
"""Find the shortest common prefix among a list of dot-separated keys.""" """Find the shortest common prefix among a list of dot-separated keys."""
if not keys: if not keys:
return "" return ""
if len(keys) == 1: if len(keys) == 1:
return keys[0] return keys[0]
# Split all keys into parts # Split all keys into parts
key_parts = [k.split('.') for k in keys] key_parts = [k.split('.') for k in keys]
# Find how many initial parts are common # Find how many initial parts are common
common_length = 0 common_length = 0
for parts in zip(*key_parts): for parts in zip(*key_parts):
@@ -841,10 +839,10 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
common_length += 1 common_length += 1
else: else:
break break
if common_length == 0: if common_length == 0:
return "" return ""
# Return the common prefix # Return the common prefix
return '.'.join(key_parts[0][:common_length]) return '.'.join(key_parts[0][:common_length])
@@ -858,5 +856,5 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
summary[prefix] = value summary[prefix] = value
else: else:
summary[""] = value # Use empty string if no common prefix summary[""] = value # Use empty string if no common prefix
return summary return summary
File diff suppressed because it is too large Load Diff
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
import inspect import inspect
from dataclasses import dataclass, asdict, field, fields import re
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils.import_utils import is_torch_available from ..utils.import_utils import is_torch_available
from ..configuration_utils import FrozenDict, ConfigMixin
if is_torch_available(): if is_torch_available():
import torch import torch
@@ -56,50 +57,50 @@ class ComponentSpec:
variant: Optional[str] = field(default=None, metadata={"loading": True}) variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True}) revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
def __hash__(self): def __hash__(self):
"""Make ComponentSpec hashable, using load_id as the hash value.""" """Make ComponentSpec hashable, using load_id as the hash value."""
return hash((self.name, self.load_id, self.default_creation_method)) return hash((self.name, self.load_id, self.default_creation_method))
def __eq__(self, other): def __eq__(self, other):
"""Compare ComponentSpec objects based on name and load_id.""" """Compare ComponentSpec objects based on name and load_id."""
if not isinstance(other, ComponentSpec): if not isinstance(other, ComponentSpec):
return False return False
return (self.name == other.name and return (self.name == other.name and
self.load_id == other.load_id and self.load_id == other.load_id and
self.default_creation_method == other.default_creation_method) self.default_creation_method == other.default_creation_method)
@classmethod @classmethod
def from_component(cls, name: str, component: torch.nn.Module) -> Any: def from_component(cls, name: str, component: torch.nn.Module) -> Any:
"""Create a ComponentSpec from a Component created by `create` method.""" """Create a ComponentSpec from a Component created by `create` method."""
if not hasattr(component, "_diffusers_load_id"): if not hasattr(component, "_diffusers_load_id"):
raise ValueError("Component is not created by `create` method") raise ValueError("Component is not created by `create` method")
type_hint = component.__class__ type_hint = component.__class__
if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin):
config = component.config config = component.config
else: else:
config = None config = None
load_spec = cls.decode_load_id(component._diffusers_load_id) load_spec = cls.decode_load_id(component._diffusers_load_id)
return cls(name=name, type_hint=type_hint, config=config, **load_spec) return cls(name=name, type_hint=type_hint, config=config, **load_spec)
@classmethod @classmethod
def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any:
"""Create a ComponentSpec from a load_id string.""" """Create a ComponentSpec from a load_id string."""
if load_id == "null": if load_id == "null":
raise ValueError("Cannot create ComponentSpec from null load_id") raise ValueError("Cannot create ComponentSpec from null load_id")
# Decode the load_id into a dictionary of loading fields # Decode the load_id into a dictionary of loading fields
load_fields = cls.decode_load_id(load_id) load_fields = cls.decode_load_id(load_id)
# Create a new ComponentSpec instance with the decoded fields # Create a new ComponentSpec instance with the decoded fields
return cls(name=name, **load_fields) return cls(name=name, **load_fields)
@classmethod @classmethod
def loading_fields(cls) -> List[str]: def loading_fields(cls) -> List[str]:
""" """
@@ -107,8 +108,8 @@ class ComponentSpec:
(i.e. those whose field.metadata["loading"] is True). (i.e. those whose field.metadata["loading"] is True).
""" """
return [f.name for f in fields(cls) if f.metadata.get("loading", False)] return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
@property @property
def load_id(self) -> str: def load_id(self) -> str:
""" """
@@ -118,7 +119,7 @@ class ComponentSpec:
parts = [getattr(self, k) for k in self.loading_fields()] parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts] parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p) return "|".join(p for p in parts if p)
@classmethod @classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
""" """
@@ -139,29 +140,29 @@ class ComponentSpec:
If a segment value is "null", it's replaced with None. If a segment value is "null", it's replaced with None.
Returns None if load_id is "null" (indicating component not loaded from pretrained). Returns None if load_id is "null" (indicating component not loaded from pretrained).
""" """
# Get all loading fields in order # Get all loading fields in order
loading_fields = cls.loading_fields() loading_fields = cls.loading_fields()
result = {f: None for f in loading_fields} result = {f: None for f in loading_fields}
if load_id == "null": if load_id == "null":
return result return result
# Split the load_id # Split the load_id
parts = load_id.split("|") parts = load_id.split("|")
# Map parts to loading fields by position # Map parts to loading fields by position
for i, part in enumerate(parts): for i, part in enumerate(parts):
if i < len(loading_fields): if i < len(loading_fields):
# Convert "null" string back to None # Convert "null" string back to None
result[loading_fields[i]] = None if part == "null" else part result[loading_fields[i]] = None if part == "null" else part
return result return result
# YiYi TODO: add validator # YiYi TODO: add validator
def create(self, **kwargs) -> Any: def create(self, **kwargs) -> Any:
"""Create the component using the preferred creation method.""" """Create the component using the preferred creation method."""
# from_pretrained creation # from_pretrained creation
if self.default_creation_method == "from_pretrained": if self.default_creation_method == "from_pretrained":
return self.create_from_pretrained(**kwargs) return self.create_from_pretrained(**kwargs)
@@ -170,17 +171,17 @@ class ComponentSpec:
return self.create_from_config(**kwargs) return self.create_from_config(**kwargs)
else: else:
raise ValueError(f"Invalid creation method: {self.default_creation_method}") raise ValueError(f"Invalid creation method: {self.default_creation_method}")
def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
"""Create component using from_config with config.""" """Create component using from_config with config."""
if self.type_hint is None or not isinstance(self.type_hint, type): if self.type_hint is None or not isinstance(self.type_hint, type):
raise ValueError( raise ValueError(
f"`type_hint` is required when using from_config creation method." "`type_hint` is required when using from_config creation method."
) )
config = config or self.config or {} config = config or self.config or {}
if issubclass(self.type_hint, ConfigMixin): if issubclass(self.type_hint, ConfigMixin):
component = self.type_hint.from_config(config, **kwargs) component = self.type_hint.from_config(config, **kwargs)
else: else:
@@ -193,24 +194,24 @@ class ComponentSpec:
if k in signature_params: if k in signature_params:
init_kwargs[k] = v init_kwargs[k] = v
component = self.type_hint(**init_kwargs) component = self.type_hint(**init_kwargs)
component._diffusers_load_id = "null" component._diffusers_load_id = "null"
if hasattr(component, "config"): if hasattr(component, "config"):
self.config = component.config self.config = component.config
return component return component
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained # YiYi TODO: add guard for type of model, if it is supported by from_pretrained
def create_from_pretrained(self, **kwargs) -> Any: def create_from_pretrained(self, **kwargs) -> Any:
"""Create component using from_pretrained.""" """Create component using from_pretrained."""
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
repo = load_kwargs.pop("repo", None) repo = load_kwargs.pop("repo", None)
if repo is None: if repo is None:
raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") raise ValueError("`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
if self.type_hint is None: if self.type_hint is None:
try: try:
from diffusers import AutoModel from diffusers import AutoModel
@@ -223,19 +224,19 @@ class ComponentSpec:
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e: except Exception as e:
raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}")
if repo != self.repo: if repo != self.repo:
self.repo = repo self.repo = repo
for k, v in passed_loading_kwargs.items(): for k, v in passed_loading_kwargs.items():
if v is not None: if v is not None:
setattr(self, k, v) setattr(self, k, v)
component._diffusers_load_id = self.load_id component._diffusers_load_id = self.load_id
return component return component
@dataclass
@dataclass
class ConfigSpec: class ConfigSpec:
"""Specification for a pipeline configuration parameter.""" """Specification for a pipeline configuration parameter."""
name: str name: str
@@ -254,7 +255,7 @@ class InputParam:
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@dataclass @dataclass
class OutputParam: class OutputParam:
"""Specification for an output parameter.""" """Specification for an output parameter."""
name: str name: str
@@ -287,14 +288,14 @@ def format_inputs_short(inputs):
""" """
required_inputs = [param for param in inputs if param.required] required_inputs = [param for param in inputs if param.required]
optional_inputs = [param for param in inputs if not param.required] optional_inputs = [param for param in inputs if not param.required]
required_str = ", ".join(param.name for param in required_inputs) required_str = ", ".join(param.name for param in required_inputs)
optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
inputs_str = required_str inputs_str = required_str
if optional_str: if optional_str:
inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
return inputs_str return inputs_str
@@ -321,18 +322,18 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
input_parts.append(f"Required({inp.name})") input_parts.append(f"Required({inp.name})")
else: else:
input_parts.append(inp.name) input_parts.append(inp.name)
# Handle modified variables (appear in both inputs and outputs) # Handle modified variables (appear in both inputs and outputs)
inputs_set = {inp.name for inp in intermediates_inputs} inputs_set = {inp.name for inp in intermediates_inputs}
modified_parts = [] modified_parts = []
new_output_parts = [] new_output_parts = []
for out in intermediates_outputs: for out in intermediates_outputs:
if out.name in inputs_set: if out.name in inputs_set:
modified_parts.append(out.name) modified_parts.append(out.name)
else: else:
new_output_parts.append(out.name) new_output_parts.append(out.name)
result = [] result = []
if input_parts: if input_parts:
result.append(f" - inputs: {', '.join(input_parts)}") result.append(f" - inputs: {', '.join(input_parts)}")
@@ -340,7 +341,7 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
result.append(f" - modified: {', '.join(modified_parts)}") result.append(f" - modified: {', '.join(modified_parts)}")
if new_output_parts: if new_output_parts:
result.append(f" - outputs: {', '.join(new_output_parts)}") result.append(f" - outputs: {', '.join(new_output_parts)}")
return "\n".join(result) if result else " (none)" return "\n".join(result) if result else " (none)"
@@ -358,18 +359,18 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
""" """
if not params: if not params:
return "" return ""
base_indent = " " * indent_level base_indent = " " * indent_level
param_indent = " " * (indent_level + 4) param_indent = " " * (indent_level + 4)
desc_indent = " " * (indent_level + 8) desc_indent = " " * (indent_level + 8)
formatted_params = [] formatted_params = []
def get_type_str(type_hint): def get_type_str(type_hint):
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
return f"Union[{', '.join(types)}]" return f"Union[{', '.join(types)}]"
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
def wrap_text(text, indent, max_length): def wrap_text(text, indent, max_length):
"""Wrap text while preserving markdown links and maintaining indentation.""" """Wrap text while preserving markdown links and maintaining indentation."""
words = text.split() words = text.split()
@@ -379,7 +380,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
for word in words: for word in words:
word_length = len(word) + (1 if current_line else 0) word_length = len(word) + (1 if current_line else 0)
if current_line and current_length + word_length > max_length: if current_line and current_length + word_length > max_length:
lines.append(" ".join(current_line)) lines.append(" ".join(current_line))
current_line = [word] current_line = [word]
@@ -387,20 +388,20 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
else: else:
current_line.append(word) current_line.append(word)
current_length += word_length current_length += word_length
if current_line: if current_line:
lines.append(" ".join(current_line)) lines.append(" ".join(current_line))
return f"\n{indent}".join(lines) return f"\n{indent}".join(lines)
# Add the header # Add the header
formatted_params.append(f"{base_indent}{header}:") formatted_params.append(f"{base_indent}{header}:")
for param in params: for param in params:
# Format parameter name and type # Format parameter name and type
type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
param_str = f"{param_indent}{param.name} (`{type_str}`" param_str = f"{param_indent}{param.name} (`{type_str}`"
# Add optional tag and default value if parameter is an InputParam and optional # Add optional tag and default value if parameter is an InputParam and optional
if hasattr(param, "required"): if hasattr(param, "required"):
if not param.required: if not param.required:
@@ -408,7 +409,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
if param.default is not None: if param.default is not None:
param_str += f", defaults to {param.default}" param_str += f", defaults to {param.default}"
param_str += "):" param_str += "):"
# Add description on a new line with additional indentation and wrapping # Add description on a new line with additional indentation and wrapping
if param.description: if param.description:
desc = re.sub( desc = re.sub(
@@ -418,9 +419,9 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
) )
wrapped_desc = wrap_text(desc, desc_indent, max_line_length) wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}" param_str += f"\n{desc_indent}{wrapped_desc}"
formatted_params.append(param_str) formatted_params.append(param_str)
return "\n\n".join(formatted_params) return "\n\n".join(formatted_params)
@@ -466,42 +467,42 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty
""" """
if not components: if not components:
return "" return ""
base_indent = " " * indent_level base_indent = " " * indent_level
component_indent = " " * (indent_level + 4) component_indent = " " * (indent_level + 4)
formatted_components = [] formatted_components = []
# Add the header # Add the header
formatted_components.append(f"{base_indent}Components:") formatted_components.append(f"{base_indent}Components:")
if add_empty_lines: if add_empty_lines:
formatted_components.append("") formatted_components.append("")
# Add each component with optional empty lines between them # Add each component with optional empty lines between them
for i, component in enumerate(components): for i, component in enumerate(components):
# Get type name, handling special cases # Get type name, handling special cases
type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
component_desc = f"{component_indent}{component.name} (`{type_name}`)" component_desc = f"{component_indent}{component.name} (`{type_name}`)"
if component.description: if component.description:
component_desc += f": {component.description}" component_desc += f": {component.description}"
# Get the loading fields dynamically # Get the loading fields dynamically
loading_field_values = [] loading_field_values = []
for field_name in component.loading_fields(): for field_name in component.loading_fields():
field_value = getattr(component, field_name) field_value = getattr(component, field_name)
if field_value is not None: if field_value is not None:
loading_field_values.append(f"{field_name}={field_value}") loading_field_values.append(f"{field_name}={field_value}")
# Add loading field information if available # Add loading field information if available
if loading_field_values: if loading_field_values:
component_desc += f" [{', '.join(loading_field_values)}]" component_desc += f" [{', '.join(loading_field_values)}]"
formatted_components.append(component_desc) formatted_components.append(component_desc)
# Add an empty line after each component except the last one # Add an empty line after each component except the last one
if add_empty_lines and i < len(components) - 1: if add_empty_lines and i < len(components) - 1:
formatted_components.append("") formatted_components.append("")
return "\n".join(formatted_components) return "\n".join(formatted_components)
@@ -519,27 +520,27 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
""" """
if not configs: if not configs:
return "" return ""
base_indent = " " * indent_level base_indent = " " * indent_level
config_indent = " " * (indent_level + 4) config_indent = " " * (indent_level + 4)
formatted_configs = [] formatted_configs = []
# Add the header # Add the header
formatted_configs.append(f"{base_indent}Configs:") formatted_configs.append(f"{base_indent}Configs:")
if add_empty_lines: if add_empty_lines:
formatted_configs.append("") formatted_configs.append("")
# Add each config with optional empty lines between them # Add each config with optional empty lines between them
for i, config in enumerate(configs): for i, config in enumerate(configs):
config_desc = f"{config_indent}{config.name} (default: {config.default})" config_desc = f"{config_indent}{config.name} (default: {config.default})"
if config.description: if config.description:
config_desc += f": {config.description}" config_desc += f": {config.description}"
formatted_configs.append(config_desc) formatted_configs.append(config_desc)
# Add an empty line after each config except the last one # Add an empty line after each config except the last one
if add_empty_lines and i < len(configs) - 1: if add_empty_lines and i < len(configs) - 1:
formatted_configs.append("") formatted_configs.append("")
return "\n".join(formatted_configs) return "\n".join(formatted_configs)
@@ -584,9 +585,9 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description="", class
# Add inputs section # Add inputs section
output += format_input_params(inputs + intermediates_inputs, indent_level=2) output += format_input_params(inputs + intermediates_inputs, indent_level=2)
# Add outputs section # Add outputs section
output += "\n\n" output += "\n\n"
output += format_output_params(outputs, indent_level=2) output += format_output_params(outputs, indent_level=2)
return output return output
@@ -334,6 +334,7 @@ def maybe_raise_or_warn(
# a simpler version of get_class_obj_and_candidates, it won't work with custom code # a simpler version of get_class_obj_and_candidates, it won't work with custom code
def simple_get_class_obj(library_name, class_name): def simple_get_class_obj(library_name, class_name):
from diffusers import pipelines from diffusers import pipelines
is_pipeline_module = hasattr(pipelines, library_name) is_pipeline_module = hasattr(pipelines, library_name)
if is_pipeline_module: if is_pipeline_module:
@@ -345,6 +346,7 @@ def simple_get_class_obj(library_name, class_name):
return class_obj return class_obj
def get_class_obj_and_candidates( def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
): ):
+1 -1
View File
@@ -1120,7 +1120,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
automatically detect the available accelerator and use. automatically detect the available accelerator and use.
""" """
self._maybe_raise_error_if_group_offload_active(raise_error=True) self._maybe_raise_error_if_group_offload_active(raise_error=True)
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
@@ -61,6 +61,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
from .pipeline_stable_diffusion_xl_modular import ( from .pipeline_stable_diffusion_xl_modular import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDecodeLatentsStep, StableDiffusionXLDecodeLatentsStep,
StableDiffusionXLDenoiseStep, StableDiffusionXLDenoiseStep,
@@ -70,7 +71,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep, StableDiffusionXLSetTimestepsStep,
StableDiffusionXLTextEncoderStep, StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoPipeline,
) )
try: try:
File diff suppressed because it is too large Load Diff
+148 -6
View File
@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
"""Utilities to dynamically load objects from the Hub.""" """Utilities to dynamically load objects from the Hub."""
import hashlib
import importlib import importlib
import inspect import inspect
import json import json
@@ -21,8 +22,9 @@ import os
import re import re
import shutil import shutil
import sys import sys
import threading
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, ModuleType, Optional, Union
from urllib import request from urllib import request
from huggingface_hub import hf_hub_download, model_info from huggingface_hub import hf_hub_download, model_info
@@ -37,6 +39,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror # See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror" COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
_HF_REMOTE_CODE_LOCK = threading.Lock()
def get_diffusers_versions(): def get_diffusers_versions():
@@ -154,15 +157,132 @@ def check_imports(filename):
return get_relative_imports(filename) return get_relative_imports(filename)
def get_class_in_module(class_name, module_path): def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
if trust_remote_code is None:
if has_local_code:
trust_remote_code = False
elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
prev_sig_handler = None
try:
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
if prev_sig_handler is not None:
signal.signal(signal.SIGALRM, prev_sig_handler)
signal.alarm(0)
elif has_remote_code:
# For the CI which puts the timeout at 0
_raise_timeout_error(None, None)
if has_remote_code and not has_local_code and not trust_remote_code:
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
)
return trust_remote_code
def get_class_in_modular_module(
class_name: str,
module_path: Union[str, os.PathLike],
*,
force_reload: bool = False,
) -> type:
"""
Import a module on the cache directory for modules and extract a class from it.
Args:
class_name (`str`): The name of the class to import.
module_path (`str` or `os.PathLike`): The path to the module to import.
force_reload (`bool`, *optional*, defaults to `False`):
Whether to reload the dynamic module from file if it already exists in `sys.modules`.
Otherwise, the module is only reloaded if the file has changed.
Returns:
`typing.Type`: The class looked for.
"""
name = os.path.normpath(module_path)
if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
# Hash the module file and all its relative imports to check if we need to reload it
module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
# reload in both cases, unless the module is already imported and the hash hits
if getattr(module, "__transformers_module_hash__", "") != module_hash:
module_spec.loader.exec_module(module)
module.__transformers_module_hash__ = module_hash
return getattr(module, class_name)
def get_class_in_module(class_name, module_path, force_reload=False):
""" """
Import a module on the cache directory for modules and extract a class from it. Import a module on the cache directory for modules and extract a class from it.
""" """
module_path = module_path.replace(os.path.sep, ".") name = os.path.normpath(module_path)
module = importlib.import_module(module_path) if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
module_spec.loader.exec_module(module)
if class_name is None: if class_name is None:
return find_pipeline_class(module) return find_pipeline_class(module)
return getattr(module, class_name) return getattr(module, class_name)
@@ -203,6 +323,7 @@ def get_cached_module_file(
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
is_modular: bool = False,
): ):
""" """
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -257,7 +378,7 @@ def get_cached_module_file(
if os.path.isfile(module_file_or_url): if os.path.isfile(module_file_or_url):
resolved_module_file = module_file_or_url resolved_module_file = module_file_or_url
submodule = "local" submodule = "local"
elif pretrained_model_name_or_path.count("/") == 0: elif pretrained_model_name_or_path.count("/") == 0 and not is_modular:
available_versions = get_diffusers_versions() available_versions = get_diffusers_versions()
# cut ".dev0" # cut ".dev0"
latest_version = "v" + ".".join(__version__.split(".")[:3]) latest_version = "v" + ".".join(__version__.split(".")[:3])
@@ -297,6 +418,24 @@ def get_cached_module_file(
except EnvironmentError: except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise raise
elif is_modular:
try:
# Load from URL or cache if already cached
resolved_module_file = hf_hub_download(
pretrained_model_name_or_path,
module_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
)
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise
else: else:
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
@@ -381,6 +520,7 @@ def get_class_from_dynamic_module(
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
is_modular: bool = False,
**kwargs, **kwargs,
): ):
""" """
@@ -453,5 +593,7 @@ def get_class_from_dynamic_module(
token=token, token=token,
revision=revision, revision=revision,
local_files_only=local_files_only, local_files_only=local_files_only,
is_modular=is_modular,
) )
return get_class_in_module(class_name, final_module.replace(".py", "")) __import__("ipdb").set_trace()
return get_class_in_module(class_name, final_module)